Hi!
I am trying to compute the JVP (forward-mode AD) of attention (FlashAttention), but I get the following NotImplementedError:
```
NotImplementedError: Trying to use forward AD with _scaled_dot_product_efficient_attention that does not support it because it has not been implemented yet.
Please file an issue to PyTorch at https://github.com/pytorch/pytorch/issues/new?template=feature-request.yml so that we can prioritize its implementation
```
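For reference, here is a minimal sketch of the kind of call that hits this error. The shapes and the explicitly forced backend are illustrative assumptions on my part; any inputs that dispatch to the memory-efficient kernel should trigger it:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

# Illustrative shapes: (batch, heads, seq_len, head_dim).
q = torch.randn(2, 8, 128, 64, device="cuda")
k = torch.randn(2, 8, 128, 64, device="cuda")
v = torch.randn(2, 8, 128, 64, device="cuda")
tq, tk, tv = torch.randn_like(q), torch.randn_like(k), torch.randn_like(v)

# Force the memory-efficient backend so SDPA dispatches to
# _scaled_dot_product_efficient_attention, then take a JVP through it.
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    out, out_tangent = torch.func.jvp(
        F.scaled_dot_product_attention, (q, k, v), (tq, tk, tv)
    )
# Raises: NotImplementedError: Trying to use forward AD with
# _scaled_dot_product_efficient_attention ...
```

As a temporary workaround, forcing the math backend (`SDPBackend.MATH`) seems to avoid the error, since it decomposes attention into primitive ops that already support forward AD, though that gives up the fused kernel's speed and memory savings.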
I would greatly appreciate it if this could be implemented! :)
Thank you so much! Looking forward to your response!
cc @ezyang @albanD @gqchen @nikitaved @soulitzer @Varal7 @xmfan