[FlexAttention] allow custom mask mod #37692
liangel-02 wants to merge 1 commit into vllm-project:main from
Conversation
Code Review
This pull request introduces a new `block_sparsity_hint` parameter to the `FlexAttentionMetadata` class and modifies the attention mechanism to allow custom mask modifications. The changes aim to provide more flexibility in defining attention patterns, including support for custom sparsity hints. The code has been reviewed and a critical issue has been identified.
```diff
 # (causal mask for decoder or bidirectional mask for encoder)
-if self.causal:
+has_custom_mask = self.logical_mask_mod is not causal_mask_mod
+if self.causal or has_custom_mask:
```
The condition `self.causal or has_custom_mask` always evaluates to `True` when `has_custom_mask` is `True`, so the code takes the `self.get_causal_mask_mod()` path whenever a custom mask is present, regardless of the value of `self.causal`. This is likely not the intended behavior: a user might want a bidirectional mask combined with a custom modification, and forcing the causal path could produce unexpected or incorrect attention patterns.
To fix this, check for a custom mask first, so that a custom mask overrides the causal/bidirectional choice; `self.causal` should only be consulted when no custom mask is present.
```diff
-if self.causal or has_custom_mask:
+if has_custom_mask:
+    mask_mod = self.logical_mask_mod
+elif self.causal:
+    mask_mod = self.get_causal_mask_mod()
+else:
+    mask_mod = self.get_bidirectional_mask_mod()
```
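For context, a FlexAttention mask mod is a callable with the signature `(b, h, q_idx, kv_idx) -> bool` that decides whether a query position may attend to a key position. A minimal sketch of what a user-supplied `logical_mask_mod` might look like is below; the `sliding_window_mask_mod` helper is hypothetical and for illustration only, and plain ints are used where real FlexAttention would pass tensors:

```python
# Sketch of FlexAttention-style mask mods. Signature: (b, h, q_idx, kv_idx) -> bool,
# where b is the batch index and h the head index (unused in these examples).

def causal_mask_mod(b, h, q_idx, kv_idx):
    # Standard causal mask: a query attends only to keys at or before its position.
    return q_idx >= kv_idx

def sliding_window_mask_mod(window: int):
    # Hypothetical custom mask mod: causal attention restricted to a local window.
    def mask_mod(b, h, q_idx, kv_idx):
        return (kv_idx <= q_idx) and (q_idx - kv_idx < window)
    return mask_mod

# With the suggested logic above, passing a custom mod would set has_custom_mask:
logical_mask_mod = sliding_window_mask_mod(window=4)
has_custom_mask = logical_mask_mod is not causal_mask_mod  # True here
```

The identity check (`is not causal_mask_mod`) only detects that a different callable was supplied, which is why the custom branch must take priority over the causal/bidirectional defaults.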
Signed-off-by: Angel Li <[email protected]>
LucasWilkinson
left a comment
@drisspg do you think you can help review this?
Updating the FlexAttention impl to accept a custom mask mod from users.