Summary
We recently landed support for grouped query attention via the `enable_gqa` flag
on SDPA; however, this is only enabled for the flash-attention backend. This leads to a counterintuitive situation where a user could be better off not passing the `enable_gqa` flag and instead calling `repeat_interleave` on the key/value tensors prior to calling SDPA, in order to use the cuDNN backend.
It looks like there is explicit support for the GQA case in the cuDNN API; we should add support for this.
https://docs.nvidia.com/deeplearning/cudnn/latest/api/cudnn-graph-library.html#cudnn-backend-operation-reduction-descriptor
cc @msaroufim @mikaylagawarecki @jainapurva, @eqy, @Skylion007