Just realized this is potentially breaking:

    "enable_gqa": enable_gqa,

I believe `enable_gqa` was first released in PyTorch 2.5.0. On earlier versions, this flag should not be passed. Additionally, some user code may wrap SDPA and not expect the `enable_gqa` flag (think of `TorchFunctionMode` being used to intercept the call).
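
One possible way to guard this is to add the key only when it is both requested and supported. This is a minimal sketch, not a drop-in patch: `supports_enable_gqa` and `build_sdpa_kwargs` are hypothetical helper names, and the 2.5 cutoff is the assumption stated above. In real code the version string would come from `torch.__version__`.

```python
import re


def supports_enable_gqa(torch_version: str) -> bool:
    """Return True if the version string is at least 2.5 (assumed cutoff)."""
    # Parse only "major.minor"; suffixes such as "+cu124" are ignored.
    match = re.match(r"(\d+)\.(\d+)", torch_version)
    if match is None:
        return False  # Unknown format: be conservative and omit the flag.
    return (int(match.group(1)), int(match.group(2))) >= (2, 5)


def build_sdpa_kwargs(torch_version: str, enable_gqa: bool, **base_kwargs):
    """Build SDPA keyword arguments, adding enable_gqa only when needed."""
    kwargs = dict(base_kwargs)
    # Pass the flag only when it is True and supported, so older versions
    # and wrappers that do not know the argument see the same call as before.
    if enable_gqa and supports_enable_gqa(torch_version):
        kwargs["enable_gqa"] = True
    return kwargs
```

Note that on an older version this sketch silently drops the flag, so the caller would still need its own fallback for grouped-query attention (for example, repeating the key/value heads before the call).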