
[FlexAttention] allow custom mask mod #37692

Open

liangel-02 wants to merge 1 commit into vllm-project:main from liangel-02:flex

Conversation

@liangel-02 (Contributor) commented Mar 20, 2026

Updates the FlexAttention implementation to accept a custom mask mod from users.
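For context, a FlexAttention mask mod is a predicate over (batch, head, query index, kv index) that decides whether a query position may attend to a kv position. A minimal sketch of the kind of user-defined mask mod this PR would accept, using a hypothetical sliding-window mask and plain Python ints in place of tensors:

```python
# Sketch only: the mask_mod(b, h, q_idx, kv_idx) signature follows
# FlexAttention's convention; WINDOW and sliding_window_mask_mod are
# illustrative names, not part of this PR.

WINDOW = 4  # hypothetical local-attention window size

def causal_mask_mod(b, h, q_idx, kv_idx):
    # Default causal mask: a query attends only to itself and earlier positions.
    return q_idx >= kv_idx

def sliding_window_mask_mod(b, h, q_idx, kv_idx):
    # Custom mask: causal attention restricted to the last WINDOW positions.
    return q_idx >= kv_idx and q_idx - kv_idx < WINDOW
```

In real use these predicates operate on index tensors inside the compiled FlexAttention kernel; the scalar version above only illustrates the contract.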

@gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request introduces a new block_sparsity_hint parameter to the FlexAttentionMetadata class and modifies the attention mechanism to allow for custom mask modifications. The changes aim to provide more flexibility in defining attention patterns, including support for custom sparsity hints. The code has been reviewed and a critical issue has been identified.

# (causal mask for decoder or bidirectional mask for encoder)
-        if self.causal:
+        has_custom_mask = self.logical_mask_mod is not causal_mask_mod
+        if self.causal or has_custom_mask:

critical

The condition self.causal or has_custom_mask evaluates to True whenever has_custom_mask is True, so the code takes the self.get_causal_mask_mod() path whenever a custom mask is present, regardless of the value of self.causal. This may not be the intended behavior: a user might want a bidirectional mask combined with a custom modification, and the current logic would apply a causal mask instead, producing unexpected or incorrect attention patterns.

To fix this, self.causal should only be consulted when no custom mask is present; a custom mask should override the causal/bidirectional choice.

Suggested change

-        if self.causal or has_custom_mask:
+        if has_custom_mask:
+            mask_mod = self.logical_mask_mod
+        elif self.causal:
+            mask_mod = self.get_causal_mask_mod()
+        else:
+            mask_mod = self.get_bidirectional_mask_mod()
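In isolation, the precedence the suggestion establishes (custom mask first, then causal, then bidirectional) can be sketched as follows; the class and helper names below are simplified stand-ins for vLLM's FlexAttentionMetadata, not its actual fields:

```python
# Stand-in for the mask-mod selection logic in the suggestion above.
# FakeMetadata is a hypothetical, simplified version of the real metadata class.

def causal_mask_mod(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

def bidirectional_mask_mod(b, h, q_idx, kv_idx):
    return True

class FakeMetadata:
    def __init__(self, causal, logical_mask_mod=causal_mask_mod):
        self.causal = causal
        self.logical_mask_mod = logical_mask_mod

    def select_mask_mod(self):
        # A user-supplied mask mod overrides both defaults, so a
        # bidirectional model with a custom mask is no longer forced
        # onto the causal path.
        has_custom_mask = self.logical_mask_mod is not causal_mask_mod
        if has_custom_mask:
            return self.logical_mask_mod
        elif self.causal:
            return causal_mask_mod
        else:
            return bidirectional_mask_mod
```

With this ordering, FakeMetadata(causal=False, logical_mask_mod=my_mod).select_mask_mod() returns my_mod, whereas the original `self.causal or has_custom_mask` condition would have routed it through the causal path.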

@LucasWilkinson (Collaborator) left a comment


@drisspg do you think you can help review this?
