@@ -1040,7 +1040,39 @@ def backward(
         grad_out: torch.Tensor,
         *args,
     ):
-        raise NotImplementedError("Backward pass is not implemented for TemplatedUlyssesAttention.")
+        parallel_config = _AttentionBackendRegistry._parallel_config
+        ulysses_mesh = parallel_config._ulysses_mesh
+        world_size = parallel_config.ulysses_degree
+        group = ulysses_mesh.get_group()
+
+        B, S_LOCAL, H, D = grad_out.shape
+        H_LOCAL = H // world_size
+
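+        # All-to-all #1: gather the full sequence, scatter heads, i.e.
+        # (B, S_LOCAL, H, D) -> (B, S, H_LOCAL, D), matching the layout the forward pass fed to the op.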
+        grad_out = grad_out.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
+        grad_out = _wait_tensor(funcol.all_to_all_single(grad_out, None, None, group=group))
+        grad_out = grad_out.flatten(0, 1).permute(1, 0, 2, 3).contiguous()
+
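+        # Run the wrapped attention op's backward on head-sharded, full-sequence tensors.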
+        grad_query_op, grad_key_op, grad_value_op, *_ = ctx.backward_op(ctx.op_ctx, grad_out)
+
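+        # All-to-all #2 (inverse): gather heads, scatter the sequence, i.e.
+        # (B, S, H_LOCAL, D) -> (B, S_LOCAL, H, D), restoring the sharding of the forward inputs.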
+        grad_query, grad_key, grad_value = (
+            x.reshape(B, world_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
+            for x in (grad_query_op, grad_key_op, grad_value_op)
+        )
+        grad_query, grad_key, grad_value = (
+            _wait_tensor(funcol.all_to_all_single(x, None, None, group=group))
+            for x in (grad_query, grad_key, grad_value)
+        )
+        grad_query, grad_key, grad_value = (
+            x.flatten(0, 1).permute(1, 2, 0, 3).contiguous() for x in (grad_query, grad_key, grad_value)
+        )
+
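+        # One gradient (or None) per forward input; the Nones cover the non-tensor arguments.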
+        return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
 
 
 def _templated_context_parallel_attention(
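For orientation, the two all-to-alls above are the standard Ulysses redistributions: grad_out moves from sequence-sharded to head-sharded before the wrapped op's backward runs, and the resulting query/key/value gradients move back. Below is a minimal, self-contained sketch of that round trip. The shard_heads/shard_sequence names are illustrative and not from this PR, funcol.wait_tensor stands in for the _wait_tensor helper used in the diff, and a process-group backend with all-to-all support is assumed.

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol


def shard_heads(t, world_size, group):
    # (B, S_local, H, D) -> (B, S, H_local, D): gather the sequence, scatter heads.
    B, S_local, H, D = t.shape
    H_local = H // world_size
    t = t.reshape(B, S_local, world_size, H_local, D).permute(2, 1, 0, 3, 4).contiguous()
    t = funcol.wait_tensor(funcol.all_to_all_single(t, None, None, group=group))
    return t.flatten(0, 1).permute(1, 0, 2, 3).contiguous()


def shard_sequence(t, world_size, group):
    # (B, S, H_local, D) -> (B, S_local, H, D): gather heads, scatter the sequence.
    B, S, H_local, D = t.shape
    S_local = S // world_size
    t = t.reshape(B, world_size, S_local, H_local, D).permute(1, 3, 0, 2, 4).contiguous()
    t = funcol.wait_tensor(funcol.all_to_all_single(t, None, None, group=group))
    return t.flatten(0, 1).permute(1, 2, 0, 3).contiguous()


if __name__ == "__main__":
    # Launch with: torchrun --nproc_per_node=2 this_file.py
    dist.init_process_group("gloo")  # CPU demo; use "nccl" and CUDA tensors on GPUs
    world_size = dist.get_world_size()
    group = dist.group.WORLD

    x = torch.randn(2, 8, 4 * world_size, 16)  # (B, S_local, H, D), H divisible by world_size
    y = shard_sequence(shard_heads(x, world_size, group), world_size, group)
    assert torch.equal(x, y)  # the two redistributions are exact inverses
    dist.destroy_process_group()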