Allow rel_pos_embedding with use_flash_attention in SABlock and CrossAttentionBlock #8842
aymuos15 wants to merge 3 commits into Project-MONAI:dev
Conversation
…roject-MONAI#7997) Lift the hard `ValueError` that prevented combining `rel_pos_embedding` with `use_flash_attention=True` in `SABlock` and `CrossAttentionBlock`. When a relative-position bias (and/or causal mask) is present, build an additive attention bias and pass it via `attn_mask` to `torch.nn.functional.scaled_dot_product_attention`. With a null bias the no-mask fast path is preserved so PyTorch can still dispatch the true flash kernel; otherwise SDPA falls back to the memory-efficient or cuDNN backend, which both accept an additive float bias with working gradients. Replace the `ValueError` unit tests with numerical-equivalence tests against the explicit attention path for 2D and 3D `input_size`. Docstrings for `use_flash_attention` are updated to clarify that backend selection is delegated to SDPA. Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
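For illustration, a minimal sketch of the routing described above, assuming tensors already shaped `(batch, heads, seq, head_dim)`; the function and variable names here are hypothetical, not the MONAI implementation:

```python
# Hypothetical sketch (not the MONAI code): route an additive relative-position
# bias through SDPA's attn_mask argument, keeping attn_mask=None when no bias
# exists so the flash kernel can still be selected.
import torch
import torch.nn.functional as F

def attention_with_optional_bias(q, k, v, rel_pos_bias=None):
    """q, k, v: (batch, heads, seq, head_dim); rel_pos_bias: (heads, seq, seq) or None."""
    if rel_pos_bias is None:
        return F.scaled_dot_product_attention(q, k, v)  # flash fast path stays available
    # Additive float bias: SDPA falls back to a backend that accepts attn_mask,
    # with working gradients through the bias.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=rel_pos_bias.to(q.dtype))

b, h, s, d = 2, 4, 16, 32
q, k, v = (torch.randn(b, h, s, d) for _ in range(3))
bias = torch.randn(h, s, s)
print(attention_with_optional_bias(q, k, v, rel_pos_bias=bias).shape)  # (2, 4, 16, 32)
```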
No actionable comments were generated in the recent review. 🎉
Actionable comments posted: 4
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@monai/networks/blocks/crossattention.py`:
- Around line 64-68: The docstring for use_flash_attention is misleading: it
claims that setting causal causes fallback to non-flash kernels, but the
implementation still dispatches scaled_dot_product_attention with is_causal=True
when no additive bias exists (see use_flash_attention, rel_pos_embedding and
is_causal=True usage in this module). Update the docstring to state that PyTorch
falls back only when a custom additive attention bias is present (for example a
merged relative-position or other additive bias), and clarify that pure causal
masking (with no additive bias) will still use the flash/SDPA fast path via
is_causal=True.
In `@monai/networks/blocks/selfattention.py`:
- Around line 66-70: Update the docstring for use_flash_attention in
selfattention.py to avoid claiming that setting causal always forces a fallback;
instead state that PyTorch falls back from the true flash kernel only when an
additive attention mask or bias is provided (e.g., custom attn_mask or built
positional/relative bias), and note that an internal is_causal=True flag (see
logic around the is_causal preservation in the block handling
rel_pos_embedding/causal between lines ~188-210) does not by itself force the
fallback.
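The behaviour both docstring comments describe can be summarised in a small sketch (approximate dispatch logic under assumed names, not the actual block code):

```python
# Approximate dispatch logic (illustrative, not the SABlock/CrossAttentionBlock code).
import torch.nn.functional as F

def run_sdpa(q, k, v, additive_bias=None, causal=False):
    if additive_bias is None:
        # Pure causal masking keeps the fast path: is_causal=True alone does
        # not force a fallback from the flash kernel.
        return F.scaled_dot_product_attention(q, k, v, is_causal=causal)
    # An additive float bias (relative-position bias, or a causal/user mask
    # already folded into it) is passed via attn_mask; SDPA then selects a
    # backend that accepts it instead of the true flash kernel.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=additive_bias)
```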
In `@tests/networks/blocks/test_crossattention.py`:
- Around line 73-94: Extend the existing
test_rel_pos_embedding_with_flash_attention to include a case where the
CrossAttentionBlock is instantiated with causal=True (i.e., exercise the
causal-bias branch / is_causal_arg=True) while rel_pos_embedding is set to
RelPosEmbedding.DECOMPOSED and use_flash_attention=True; create a matching
reference block with use_flash_attention=False and causal=True, load the flash
block state into the reference block, run both in eval_mode on the same random
input (same seq_len computation and device handling), and
assert_allclose(out_flash, out_ref, atol=1e-4) to lock the causal behavior.
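A rough sketch of the kind of causal test this comment asks for; the constructor arguments, the `"decomposed"` embedding string, and the `assert_allclose` import path are assumptions modelled on the existing tests:

```python
# Sketch of the suggested causal + rel_pos_embedding + flash test case.
# Constructor arguments, the "decomposed" string, and the assert_allclose
# import path are assumptions modelled on the existing tests.
import numpy as np
import torch
from monai.networks.blocks.crossattention import CrossAttentionBlock
from tests.test_utils import assert_allclose  # path assumed

input_size = (16, 32)
seq_len = int(np.prod(input_size))
device = "cuda" if torch.cuda.is_available() else "cpu"
params = dict(
    hidden_size=128,
    num_heads=4,
    rel_pos_embedding="decomposed",
    input_size=input_size,
    causal=True,
    sequence_length=seq_len,
)
block_flash = CrossAttentionBlock(**params, use_flash_attention=True).to(device).eval()
block_ref = CrossAttentionBlock(**params, use_flash_attention=False).to(device).eval()
block_ref.load_state_dict(block_flash.state_dict())  # share weights between paths

x = torch.randn(2, seq_len, 128, device=device)
with torch.no_grad():
    assert_allclose(block_flash(x), block_ref(x), atol=1e-4)
```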
In `@tests/networks/blocks/test_selfattention.py`:
- Around line 71-91: Extend test_rel_pos_embedding_with_flash_attention to also
assert numerical equivalence between the flash and reference SABlock paths for
the two additional branches: (1) causal + rel_pos_embedding +
use_flash_attention (exercise the merged causal bias path) and (2) attn_mask +
rel_pos_embedding + use_flash_attention (exercise the user-attn-mask
merged-into-additive-bias path). For each case, create blocks via
SABlock(**input_param, use_flash_attention=True) and SABlock(...,
use_flash_attention=False), copy state_dict from flash to ref, run both in
eval_mode on the same random input, and assert_allclose(out_flash, out_ref,
atol=1e-4); for the attn_mask case provide an appropriate attention mask tensor
passed to the forward call to trigger the attn_mask branch. Ensure you reuse
RelPosEmbedding.DECOMPOSED, input_size patterns, device selection, and
comparisons like the existing test.
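Similarly, a sketch of the `attn_mask` + `rel_pos_embedding` case for `SABlock`; the mask shape/dtype, constructor arguments, and import paths are assumptions to be adapted to the existing test style:

```python
# Sketch of the suggested attn_mask + rel_pos_embedding + flash case for SABlock.
# Mask shape/dtype, constructor arguments, and import paths are assumptions.
import numpy as np
import torch
from monai.networks.blocks.selfattention import SABlock
from tests.test_utils import assert_allclose  # path assumed

input_size = (8, 8, 8)
seq_len = int(np.prod(input_size))
device = "cuda" if torch.cuda.is_available() else "cpu"
params = dict(hidden_size=96, num_heads=4, rel_pos_embedding="decomposed", input_size=input_size)
block_flash = SABlock(**params, use_flash_attention=True).to(device).eval()
block_ref = SABlock(**params, use_flash_attention=False).to(device).eval()
block_ref.load_state_dict(block_flash.state_dict())

x = torch.randn(1, seq_len, 96, device=device)
# Boolean mask, True = attend; mask out the last four key positions so no row
# is fully masked (avoids NaNs from an all -inf softmax row).
attn_mask = torch.ones(1, seq_len, seq_len, dtype=torch.bool, device=device)
attn_mask[..., -4:] = False
with torch.no_grad():
    assert_allclose(block_flash(x, attn_mask=attn_mask), block_ref(x, attn_mask=attn_mask), atol=1e-4)
```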
📒 Files selected for processing (4)
- monai/networks/blocks/crossattention.py
- monai/networks/blocks/selfattention.py
- tests/networks/blocks/test_crossattention.py
- tests/networks/blocks/test_selfattention.py
Address CodeRabbit review on PR Project-MONAI#8842:
- Narrow the `use_flash_attention` docstring in `SABlock` and `CrossAttentionBlock` so it reflects the actual implementation: pure causal masking keeps the fast path via `is_causal=True`; only an additive bias (`rel_pos_embedding`, or causal/`attn_mask` merged with another bias) forces SDPA to fall back to the memory-efficient or cuDNN backend.
- Extend the numerical-equivalence tests to cover the new merged-bias paths: `causal=True` + `rel_pos_embedding` for both blocks, and `attn_mask` + `rel_pos_embedding` for `SABlock`. All cases assert `assert_allclose(out_flash, out_ref, atol=1e-4)` on 2D and 3D inputs.

Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
Fixes #7997.
Lift the hard `ValueError` that prevented combining `rel_pos_embedding` with `use_flash_attention=True` in `SABlock` and `CrossAttentionBlock`. Issue #7997 tracks a suggestion originally raised in a PR #7977 review comment: the relative-position bias can be routed through the additive `attn_mask` argument of `torch.nn.functional.scaled_dot_product_attention` (SDPA), at the cost of dropping out of the true flash kernel fast path.

How it works

- When no bias or mask is present, `attn_mask=None` is passed to SDPA so PyTorch can still dispatch the true flash kernel; the existing fast path is preserved.
- When a relative-position bias (and/or causal or user mask) is present, it is built into an additive float bias and passed via `attn_mask`. SDPA falls back to a backend that supports an additive float bias (typically the memory-efficient backend, or the math backend as a universal fallback). This is the trade-off acknowledged in the issue: not the real flash kernel, but still meaningfully faster and lower-memory than the explicit `QKᵀ → softmax → V` path.
- `causal=True` combined with a bias is handled by converting the boolean `causal_mask` to an additive `-inf` bias and disabling SDPA's `is_causal` (since SDPA cannot combine `is_causal=True` with a custom `attn_mask`); see the sketch below. Pure causal (no bias, no user mask) still uses `is_causal=True` and the optimised path.
- `save_attn=True` continues to raise `ValueError` when combined with `use_flash_attention=True`, since SDPA does not expose the explicit attention matrix.
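As referenced in the causal bullet above, a minimal sketch of folding a boolean causal mask into the relative-position bias so SDPA receives one additive float `attn_mask` (shapes and names are illustrative assumptions, not the exact MONAI internals):

```python
# Illustrative sketch (shapes and names assumed): fold a boolean causal mask
# into the relative-position bias so SDPA receives one additive float attn_mask.
import torch
import torch.nn.functional as F

def merge_causal_into_bias(rel_pos_bias, causal_mask):
    """rel_pos_bias: (heads, seq, seq) float; causal_mask: (seq, seq) bool, True = attend."""
    # Disallowed positions become -inf so softmax gives them zero weight;
    # is_causal is then left False, since SDPA rejects is_causal=True together
    # with a custom attn_mask.
    return rel_pos_bias.masked_fill(~causal_mask, float("-inf"))

heads, seq, dim = 4, 8, 16
q, k, v = (torch.randn(1, heads, seq, dim) for _ in range(3))
rel_bias = torch.randn(heads, seq, seq)
causal = torch.tril(torch.ones(seq, seq)).bool()
out = F.scaled_dot_product_attention(q, k, v, attn_mask=merge_causal_into_bias(rel_bias, causal))
print(out.shape)  # torch.Size([1, 4, 8, 16])
```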
The previous "raises
ValueError" tests are replaced with numerical-equivalence tests (assert_allclose(out_flash, out_ref, atol=1e-4)) for both 2D(16, 32)and 3D(8, 8, 8)input_size, comparing the flash path against the explicit attention path with shared weights. Same coverage added forCrossAttentionBlock.Docstrings for
use_flash_attentionare updated to clarify that backend selection is delegated to SDPA.Types of changes