
Allow rel_pos_embedding with use_flash_attention in SABlock and CrossAttentionBlock #8842

Open

aymuos15 wants to merge 3 commits into Project-MONAI:dev from aymuos15:fix-7997-flash-relpos

Conversation

Contributor

aymuos15 commented May 4, 2026

Fixes #7997.

Lift the hard ValueError that prevented combining rel_pos_embedding with use_flash_attention=True in SABlock and CrossAttentionBlock. Issue #7997 tracks a suggestion originally raised in a review comment on PR #7977: the relative-position bias can be routed through the additive attn_mask argument of torch.nn.functional.scaled_dot_product_attention (SDPA), at the cost of dropping out of the true flash kernel fast path.
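As a minimal sketch of that routing (not the MONAI implementation itself; the helper name and tensor shapes are illustrative), the decision reduces to whether an additive bias exists when calling SDPA:

```python
import torch.nn.functional as F

def sdpa_with_optional_bias(q, k, v, rel_pos_bias=None):
    """q, k, v: (batch, num_heads, seq_len, head_dim).
    rel_pos_bias: additive float tensor broadcastable to
    (batch, num_heads, seq_len, seq_len), or None."""
    if rel_pos_bias is None:
        # No bias: attn_mask=None keeps the call eligible for the true flash kernel.
        return F.scaled_dot_product_attention(q, k, v, attn_mask=None)
    # Additive float bias: SDPA selects a backend that supports it
    # (typically memory-efficient, or math as the universal fallback).
    return F.scaled_dot_product_attention(q, k, v, attn_mask=rel_pos_bias)
```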

How it works

  • When no rel-pos bias / causal mask / user mask is needed, attn_mask=None is passed to SDPA so PyTorch can still dispatch the true flash kernel — the existing fast path is preserved.
  • When a bias is needed, it is built once and passed as an additive float tensor via attn_mask. SDPA falls back to a backend that supports an additive float bias (typically the memory-efficient backend, or the math backend as a universal fallback). This is the trade-off acknowledged in the issue: not the real flash kernel, but still meaningfully faster and lower-memory than the explicit QKᵀ → softmax → V path.
  • causal=True combined with a bias is handled by converting the boolean causal_mask to an additive -inf bias and disabling SDPA's is_causal, since SDPA cannot combine is_causal=True with a custom attn_mask (see the sketch after this list). Pure causal (no bias, no user mask) still uses is_causal=True and the optimised path.
  • save_attn=True continues to raise ValueError when combined with use_flash_attention=True, since SDPA does not expose the explicit attention matrix.
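A rough sketch of the bias-merge step referenced above (the helper name is hypothetical; the masks are boolean tensors broadcastable to the bias, True where attention is allowed):

```python
import torch

def merge_masks_into_bias(
    bias: torch.Tensor,
    causal_mask: torch.Tensor | None = None,
    user_mask: torch.Tensor | None = None,
) -> torch.Tensor:
    # Fold boolean masks into the additive float bias: disallowed positions
    # get -inf so they contribute nothing after the softmax inside SDPA.
    if user_mask is not None:
        bias = bias.masked_fill(~user_mask, float("-inf"))
    if causal_mask is not None:
        bias = bias.masked_fill(~causal_mask, float("-inf"))
    return bias

# Once a bias exists, SDPA's is_causal flag must stay False, since SDPA rejects
# is_causal=True together with an explicit attn_mask:
# out = F.scaled_dot_product_attention(q, k, v, attn_mask=merged_bias, is_causal=False)
```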

Tests

The previous "raises ValueError" tests are replaced with numerical-equivalence tests (assert_allclose(out_flash, out_ref, atol=1e-4)) for both 2D (16, 32) and 3D (8, 8, 8) input_size, comparing the flash path against the explicit attention path with shared weights. Same coverage added for CrossAttentionBlock.
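A condensed version of what such an equivalence test looks like (the constructor values and shapes are illustrative, and torch.testing.assert_close stands in here for MONAI's assert_allclose helper):

```python
import torch
from monai.networks.blocks.selfattention import SABlock

params = dict(hidden_size=128, num_heads=4, rel_pos_embedding="decomposed", input_size=(16, 32))
flash = SABlock(**params, use_flash_attention=True).eval()
ref = SABlock(**params, use_flash_attention=False).eval()
ref.load_state_dict(flash.state_dict())  # share weights so only the attention path differs

x = torch.randn(1, 16 * 32, 128)  # (batch, sequence length, hidden_size)
with torch.no_grad():
    torch.testing.assert_close(flash(x), ref(x), atol=1e-4, rtol=1e-4)
```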

Docstrings for use_flash_attention are updated to clarify that backend selection is delegated to SDPA.

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • New tests added to cover the changes.
  • In-line docstrings updated.

…roject-MONAI#7997)

Lift the hard `ValueError` that prevented combining `rel_pos_embedding` with
`use_flash_attention=True` in `SABlock` and `CrossAttentionBlock`. When a
relative-position bias (and/or causal mask) is present, build an additive
attention bias and pass it via `attn_mask` to
`torch.nn.functional.scaled_dot_product_attention`. With a null bias the
no-mask fast path is preserved so PyTorch can still dispatch the true flash
kernel; otherwise SDPA falls back to the memory-efficient or cuDNN backend,
both of which accept an additive float bias with working gradients.

Replace the `ValueError` unit tests with numerical-equivalence tests against
the explicit attention path for 2D and 3D `input_size`. Docstrings for
`use_flash_attention` are updated to clarify that backend selection is
delegated to SDPA.

Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
Contributor

coderabbitai Bot commented May 4, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: fcfb7387-0008-44c1-918c-2036dedbe289

📥 Commits

Reviewing files that changed from the base of the PR and between b7d4786 and c5b2a1c.

📒 Files selected for processing (4)
  • monai/networks/blocks/crossattention.py
  • monai/networks/blocks/selfattention.py
  • tests/networks/blocks/test_crossattention.py
  • tests/networks/blocks/test_selfattention.py
✅ Files skipped from review due to trivial changes (1)
  • monai/networks/blocks/selfattention.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • tests/networks/blocks/test_crossattention.py
  • monai/networks/blocks/crossattention.py

📝 Walkthrough

Initialization-time checks preventing use_flash_attention with rel_pos_embedding were removed from CrossAttentionBlock and SABlock. Both blocks now synthesize an additive attention bias from relative positional embeddings and (when present) merge causal masking and user attn_mask into that bias, then pass it as attn_mask to torch.nn.functional.scaled_dot_product_attention (adjusting the SDPA is_causal argument when causal masking is merged). Tests were updated/added to assert numerical equivalence between flash and non-flash code paths for rel-pos, causal+rel-pos, and attn_mask+rel-pos scenarios.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage (⚠️ Warning): Docstring coverage is 30.77% which is insufficient. The required threshold is 80.00%. Resolution: Write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (4 passed)

  • Title check: Title directly and accurately summarizes the main change: enabling rel_pos_embedding with flash attention in two key attention blocks.
  • Description check: Description includes all required sections: issue reference, clear explanation of changes, types of changes marked, and test coverage detailed.
  • Linked Issues check: Changes fully satisfy issue #7997: rel_pos_embedding now works with use_flash_attention via additive attn_mask, preserving fast paths when no bias needed.
  • Out of Scope Changes check: All changes directly address the stated objective: removing the ValueError restriction and implementing bias routing through SDPA. No extraneous modifications detected.


Contributor

coderabbitai Bot left a comment


Actionable comments posted: 4

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@monai/networks/blocks/crossattention.py`:
- Around line 64-68: The docstring for use_flash_attention is misleading: it
claims that setting causal causes fallback to non-flash kernels, but the
implementation still dispatches scaled_dot_product_attention with is_causal=True
when no additive bias exists (see use_flash_attention, rel_pos_embedding and
is_causal=True usage in this module). Update the docstring to state that PyTorch
falls back only when a custom additive attention bias is present (for example a
merged relative-position or other additive bias), and clarify that pure causal
masking (with no additive bias) will still use the flash/SDPA fast path via
is_causal=True.

In `@monai/networks/blocks/selfattention.py`:
- Around line 66-70: Update the docstring for use_flash_attention in
selfattention.py to avoid claiming that setting causal always forces a fallback;
instead state that PyTorch falls back from the true flash kernel only when an
additive attention mask or bias is provided (e.g., custom attn_mask or built
positional/relative bias), and note that an internal is_causal=True flag (see
logic around the is_causal preservation in the block handling
rel_pos_embedding/causal between lines ~188-210) does not by itself force the
fallback.

In `@tests/networks/blocks/test_crossattention.py`:
- Around line 73-94: Extend the existing
test_rel_pos_embedding_with_flash_attention to include a case where the
CrossAttentionBlock is instantiated with causal=True (i.e., exercise the
causal-bias branch / is_causal_arg=True) while rel_pos_embedding is set to
RelPosEmbedding.DECOMPOSED and use_flash_attention=True; create a matching
reference block with use_flash_attention=False and causal=True, load the flash
block state into the reference block, run both in eval_mode on the same random
input (same seq_len computation and device handling), and
assert_allclose(out_flash, out_ref, atol=1e-4) to lock the causal behavior.

In `@tests/networks/blocks/test_selfattention.py`:
- Around line 71-91: Extend test_rel_pos_embedding_with_flash_attention to also
assert numerical equivalence between the flash and reference SABlock paths for
the two additional branches: (1) causal + rel_pos_embedding +
use_flash_attention (exercise the merged causal bias path) and (2) attn_mask +
rel_pos_embedding + use_flash_attention (exercise the user-attn-mask
merged-into-additive-bias path). For each case, create blocks via
SABlock(**input_param, use_flash_attention=True) and SABlock(...,
use_flash_attention=False), copy state_dict from flash to ref, run both in
eval_mode on the same random input, and assert_allclose(out_flash, out_ref,
atol=1e-4); for the attn_mask case provide an appropriate attention mask tensor
passed to the forward call to trigger the attn_mask branch. Ensure you reuse
RelPosEmbedding.DECOMPOSED, input_size patterns, device selection, and
comparisons like the existing test.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 55e75fbd-5397-4df9-9b85-5c0306c1f94f

📥 Commits

Reviewing files that changed from the base of the PR and between 65beb58 and b7d4786.

📒 Files selected for processing (4)
  • monai/networks/blocks/crossattention.py
  • monai/networks/blocks/selfattention.py
  • tests/networks/blocks/test_crossattention.py
  • tests/networks/blocks/test_selfattention.py

Comment threads:
  • monai/networks/blocks/crossattention.py (outdated)
  • monai/networks/blocks/selfattention.py (outdated)
  • tests/networks/blocks/test_crossattention.py
  • tests/networks/blocks/test_selfattention.py
aymuos15 added 2 commits May 4, 2026 10:47
Address CodeRabbit review on PR Project-MONAI#8842:

- Narrow the use_flash_attention docstring in SABlock and
  CrossAttentionBlock so it reflects the actual implementation: pure
  causal masking keeps the fast path via is_causal=True; only an
  additive bias (rel_pos_embedding, or causal/attn_mask merged with
  another bias) forces SDPA to fall back to the memory-efficient or
  cuDNN backend.
- Extend the numerical-equivalence tests to cover the new merged-bias
  paths: causal=True + rel_pos_embedding for both blocks, and
  attn_mask + rel_pos_embedding for SABlock. All cases assert
  assert_allclose(out_flash, out_ref, atol=1e-4) on 2D and 3D inputs.

Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>


Development

Successfully merging this pull request may close these issues:

  • Enable relative positional embedding in flash attention (#7997)
