Open
Labels
bug (Something isn't working)
Description
Bug report
`attention_mla.MLA` scales queries by `self.softmax_scale` after projection (`maxtext/src/MaxText/layers/attention_mla.py`, line 804 at commit `3a17530`):

```python
query = jnp.concatenate([q_nope, q_pe], axis=-1) * self.softmax_scale
```
However, `cudnn_jax_flash_attention` (the implementation used when `attention=cudnn_flash_jax`) also hardcodes a scale inside the kernel call (`maxtext/src/MaxText/layers/attention_op.py`, line 1509 at commit `3a17530`):

```python
scale=1.0 / math.sqrt(head_dim),
```
Because the already-scaled queries are scaled a second time inside the kernel, the attention results are incorrect and do not match `attention=dot_product` and the other implementations.
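A minimal sketch of the mismatch, using toy shapes rather than MaxText's actual modules (the variable names here only mirror the issue; for simplicity it assumes `softmax_scale == 1/sqrt(head_dim)`, though for MLA the two values need not even agree):

```python
import math
import jax.numpy as jnp

head_dim = 64
softmax_scale = head_dim ** -0.5  # hypothetical value for illustration

q = jnp.ones((4, head_dim))
k = jnp.ones((4, head_dim))

# Reference path (attention=dot_product): the scale is applied exactly once,
# on the projected queries.
logits_ref = (q * softmax_scale) @ k.T

# cudnn_flash_jax path as reported: MLA pre-scales the queries, then the
# kernel multiplies the logits by its own hardcoded 1/sqrt(head_dim).
logits_double = ((q * softmax_scale) @ k.T) * (1.0 / math.sqrt(head_dim))

# The two paths disagree by a factor of 1/sqrt(head_dim).
assert not jnp.allclose(logits_ref, logits_double)
```

The divergence grows with `head_dim`, and since the error sits inside the softmax logits, the resulting attention weights (not just their magnitude) differ between the two backends.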
Logs/Output
No response
Environment Information
No response
Additional Context
No response