
fix(networks): only instantiate cross-attention layers when with_cross_attention=True #8873

Open
AlexanderSanin wants to merge 1 commit into Project-MONAI:dev from AlexanderSanin:fix/transformer-block-cross-attention-always-instantiated

Conversation

@AlexanderSanin
Contributor

Summary

Fixes #8845

TransformerBlock was unconditionally creating norm_cross_attn and cross_attn in __init__ regardless of the with_cross_attention flag. This caused unused parameters to appear in model.parameters() / model.named_parameters(), increasing model size and producing confusing "no gradient" warnings during training.

Root cause — two layers created unconditionally:

# always ran, even when with_cross_attention=False
self.norm_cross_attn = nn.LayerNorm(hidden_size)
self.cross_attn = CrossAttentionBlock(...)

Fix — guard instantiation with the same flag already used in forward:

if with_cross_attention:
    self.norm_cross_attn = nn.LayerNorm(hidden_size)
    self.cross_attn = CrossAttentionBlock(...)

Changes

  • monai/networks/blocks/transformerblock.py: wrap cross-attention layer init in if with_cross_attention:
  • tests/networks/blocks/test_transformerblock.py: add two regression tests verifying parameter presence/absence

Test plan

  • test_no_cross_attention_params_when_disabled — asserts no cross_attn/norm_cross_attn params when with_cross_attention=False
  • test_cross_attention_params_when_enabled — asserts cross-attn params exist when with_cross_attention=True
  • test_ill_arg — existing validation tests unchanged
  • Existing parameterized test_shape tests cover both with_cross_attention=True/False forward passes

https://claude.ai/code/session_01LV2dy8NFh3smu9f2RfgFvs

…s_attention=True

TransformerBlock was unconditionally creating norm_cross_attn and cross_attn
even when with_cross_attention=False, causing unused parameters to appear in
model.parameters() and named_parameters(). Guard the instantiation with an
if-block so those layers only exist when they are actually used.

Closes Project-MONAI#8845

Signed-off-by: Oleksandr Sanin <alexaaander.sanin@gmail.com>
@coderabbitai
Contributor

coderabbitai Bot commented May 26, 2026

📝 Walkthrough

TransformerBlock now conditionally instantiates norm_cross_attn and cross_attn submodules only when with_cross_attention=True. Previously these modules were unconditionally created despite being used only when the flag is enabled, resulting in unused parameters appearing in model.parameters(). Two new tests verify that cross-attention parameters are absent when disabled and present when enabled.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~8 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 50.00%, which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | Title clearly and concisely summarizes the main change: conditional instantiation of cross-attention layers based on the flag. |
| Description check | ✅ Passed | Description is comprehensive with root cause, fix, changes, and test plan. Follows template structure and provides clear context. |
| Linked Issues check | ✅ Passed | Changes directly address issue #8845: conditional initialization of cross-attention layers in __init__ prevents unused parameters when with_cross_attention=False. |
| Out of Scope Changes check | ✅ Passed | All changes are in-scope: modifications to TransformerBlock.__init__ and regression tests directly address the linked issue. |


@AlexanderSanin
Contributor Author

Hey @ericspod @aymuos15, could you please take a look at this?

Contributor

@coderabbitai coderabbitai Bot left a comment


🧹 Nitpick comments (2)
tests/networks/blocks/test_transformerblock.py (2)

67-74: ⚡ Quick win

Strengthen the enabled-case parameter assertion.

any("cross_attn" in n ...) also matches norm_cross_attn; this can pass without proving cross_attn.* exists.

Proposed test tightening
     self.assertTrue(
-        any("cross_attn" in n for n in param_names),
-        "Expected cross-attention parameters not found when with_cross_attention=True",
+        any(n.startswith("cross_attn.") for n in param_names),
+        "Expected cross_attn parameters not found when with_cross_attention=True",
+    )
+    self.assertTrue(
+        any(n.startswith("norm_cross_attn.") for n in param_names),
+        "Expected norm_cross_attn parameters not found when with_cross_attention=True",
     )
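
The difference between the two assertions can be seen in isolation. This is a small sketch in plain Python; the parameter names are hypothetical examples (real names depend on the module's internals), and `has_substring` and `has_module_prefix` are illustrative helpers, not part of the test suite.

```python
# Hypothetical parameter names for a block where only the norm layer was registered.
names_only_norm = ["norm1.weight", "norm_cross_attn.weight", "norm_cross_attn.bias"]
# The same names plus one parameter that really lives under the cross_attn module.
names_full = names_only_norm + ["cross_attn.to_q.weight"]


def has_substring(names: list[str]) -> bool:
    """Permissive check: also matches 'norm_cross_attn.*'."""
    return any("cross_attn" in n for n in names)


def has_module_prefix(names: list[str], module_name: str) -> bool:
    """Strict check: only matches parameters registered under `module_name`."""
    return any(n.startswith(module_name + ".") for n in names)
```

Here `has_substring(names_only_norm)` is true even though no `cross_attn.*` parameter exists, whereas `has_module_prefix(names_only_norm, "cross_attn")` is false. The prefix form assumes the block is inspected directly; for names taken from a parent model, the prefix would be nested (for example under a container name).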
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/networks/blocks/test_transformerblock.py` around lines 67 - 74, The
assertion in test_cross_attention_params_when_enabled is too permissive because
any("cross_attn" in n ...) matches names like norm_cross_attn; update the check
to assert presence of actual cross-attention parameter namespace by matching a
stricter pattern (e.g., names that start with "cross_attn" or contain
".cross_attn."/"cross_attn."), e.g., replace the any(...) with a predicate that
uses n.startswith("cross_attn") or checks for ".cross_attn" to ensure
TransformerBlock(with_cross_attention=True) truly registers parameters under the
cross_attn module namespace.

57-58: ⚡ Quick win

Use Google-style docstrings for the new test methods.

Current one-line docstrings don’t follow the required structured format.

As per coding guidelines, "Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings."

Also applies to: 67-68

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/networks/blocks/test_transformerblock.py` around lines 57 - 58, Replace
the one-line docstrings in the test methods with Google-style docstrings: for
the test_no_cross_attention_params_when_disabled (and the adjacent test at the
following definition), expand the docstring to a short summary line followed by
sections "Args:", "Returns:", and "Raises:" (even if empty) and describe the
test inputs (if any), the expected outcome (no cross-attention parameters
registered), and that the test returns None; keep the summary concise and use
the function names test_no_cross_attention_params_when_disabled and the adjacent
test to locate and update the docstrings accordingly.
ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 88e3a7d4-470d-4b08-968c-6ac5fb14919a

📥 Commits

Reviewing files that changed from the base of the PR and between 0a8d945 and b174263.

📒 Files selected for processing (2)
  • monai/networks/blocks/transformerblock.py
  • tests/networks/blocks/test_transformerblock.py

@AlexanderSanin
Contributor Author

The failing CI checks are caused by a transient upstream dependency regression in cupy-cuda12x and are not related to this PR's changes.

Root cause: cupy-cuda12x 14.1.0 (released between May 22 and May 26) added an unconditional import pytest at module load time in cupy/testing/_random.py. pytest is not installed in the target environments for several CI jobs (only in requirements-dev.txt), so any import of monai.networks (which pulls in polygraphy.backend.common → polygraphy.util.util → cupy.testing._random) triggers a ModuleNotFoundError.

Evidence:

  • Last successful tests workflow on dev (commit 0a8d945, run #26286623163, May 22): installed cupy-cuda12x-14.0.1 → all jobs green
  • This PR (same base commit 0a8d945, May 26): installs cupy-cuda12x-14.1.0 → full-dep, static-checks (pytype) fail; mypy and codeformat cancelled during dep install timeout
  • The companion PR fix(losses): register buffers in GlobalMutualInformationLoss #8872 (different changes entirely) shows the exact same failure pattern, confirming it is environment-driven

Fix scope: This is a separate issue from this PR. The cleanest mitigation would be to pin cupy-cuda12x!=14.0.0,!=14.1.0 in requirements-dev.txt (or add pytest to the runtime-import path's environment). Happy to open a separate PR for that if maintainers agree.

This PR's tests pass locally:

test_no_cross_attention_params_when_disabled PASSED
test_cross_attention_params_when_enabled     PASSED
test_ill_arg                                 PASSED
test_access_attn_matrix                      PASSED

@AlexanderSanin
Contributor Author

AlexanderSanin commented May 26, 2026

Status & Blockers

All code changes are complete and tests pass locally (4/4). Two things are blocking merge:

1. CI Failures (external regression — not our code)

The failing CI jobs (full-dep (ubuntu-latest), static-checks (pytype), mypy, codeformat) are caused by a cupy-cuda12x 14.1.0 regression released on May 26, 2026. That version introduced import pytest at module load time in cupy/testing/_random.py, which breaks any environment where pytest is not installed.

The failure chain is:
monai.networks → trt_compiler → polygraphy.backend.common → polygraphy.util.util → cupy.testing._random → import pytest → ModuleNotFoundError

This is pre-existing on the dev branch — the last fully green dev CI run (#26286623163) was on May 22 with cupy 14.0.1. All our PR-specific tests pass.
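
The failure chain above comes down to a module-level import of a package that is only installed in development environments. As a minimal sketch of the usual defensive pattern (this is an illustration only; `try_import` is a hypothetical helper and not the code used by cupy, polygraphy, or MONAI), an optional dependency can be resolved without failing at import time:

```python
import importlib
from types import ModuleType
from typing import Optional


def try_import(name: str) -> Optional[ModuleType]:
    """Return the imported module, or None if it is not installed.

    ModuleNotFoundError is a subclass of ImportError, so a missing
    package is caught here instead of propagating to the importer.
    """
    try:
        return importlib.import_module(name)
    except ImportError:
        return None


# A dev-only dependency resolved lazily: importing this file never fails,
# and callers can check for None before using test-only features.
pytest_module = try_import("pytest")
```

A module written this way can still be imported in an environment without `pytest`; only code paths that actually need the package have to handle the `None` case.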

2. Awaiting Required Review

This PR needs at least one approving review from a code owner before it can be merged.

@KumoLiu @ericspod @Nic-Ma — could one of you take a look when you get a chance? The change is in monai/networks/blocks/transformerblock.py and guards cross-attention layer instantiation behind if with_cross_attention: so that models initialized with with_cross_attention=False don't allocate unused parameters. Thanks!

Note on backwards compatibility: Models saved with the old code that had with_cross_attention=False will have cross_attn.* and norm_cross_attn.* keys in their state_dict. Loading those checkpoints with the fixed code requires strict=False, because the default strict=True raises an error on unexpected keys.
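
An alternative to relaxing strictness is to drop the stale keys before loading, so that strict loading still catches unrelated mismatches. The sketch below works on a plain dictionary and is only an illustration: `drop_stale_cross_attention_keys` is a hypothetical helper, the key names are made up, and it assumes that no block in the model has cross-attention enabled.

```python
# Module names whose parameters no longer exist when with_cross_attention=False.
STALE_MODULES = {"cross_attn", "norm_cross_attn"}


def drop_stale_cross_attention_keys(state_dict: dict) -> dict:
    """Return a copy of `state_dict` without cross-attention entries.

    A key is dropped when any dot-separated component is one of the stale
    module names, which also covers blocks nested inside a larger model.
    """
    return {
        key: value
        for key, value in state_dict.items()
        if STALE_MODULES.isdisjoint(key.split("."))
    }


# Hypothetical checkpoint keys; the values stand in for tensors.
old_checkpoint = {
    "blocks.0.norm1.weight": 1,
    "blocks.0.norm_cross_attn.weight": 2,
    "blocks.0.cross_attn.to_q.weight": 3,
    "blocks.0.attn.out_proj.weight": 4,
}
cleaned = drop_stale_cross_attention_keys(old_checkpoint)
```

The cleaned dictionary could then be passed to `load_state_dict` with the default strict behaviour. For models that mix blocks with and without cross-attention, the filter would need to be restricted to the affected blocks.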

@ericspod
Member

Status & Blockers

All code changes are complete and tests pass locally (4/4). Two things are blocking merge:

1. CI Failures (external regression — not our code)

I'm seeing the same issue and am working on a fix right now on a PR I hope to integrate soon.



Development

Successfully merging this pull request may close these issues.

TransformerBlock instantiates CrossAttentionBlock even when with_cross_attention=False

2 participants