@@ -2962,12 +2962,6 @@ def __call__(
         # perturbed path (identity attention)
         batch_size, sequence_length, _ = hidden_states_ptb.shape
 
-        if attention_mask is not None:
-            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-            # scaled_dot_product_attention expects attention_mask shape to be
-            # (batch, heads, source_length, target_length)
-            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
         if attn.group_norm is not None:
             hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)
 
@@ -3070,12 +3064,6 @@ def __call__(
         # perturbed path (identity attention)
         batch_size, sequence_length, _ = hidden_states_ptb.shape
 
-        if attention_mask is not None:
-            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-            # scaled_dot_product_attention expects attention_mask shape to be
-            # (batch, heads, source_length, target_length)
-            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
         if attn.group_norm is not None:
             hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)
 
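The removed block prepared an attention mask that the perturbed path never uses: PAG's identity-attention branch skips the query/key attention computation entirely and only applies the value projection, so the mask would never be consumed. A minimal sketch of that branch, assuming it mirrors the processors touched above (the function name `identity_attention_ptb` and the standalone `to_v` / `group_norm` arguments are illustrative stand-ins for the `attn` submodules, not the library's API):

```python
from typing import Optional

import torch
from torch import nn


def identity_attention_ptb(
    hidden_states_ptb: torch.Tensor,  # (batch, seq_len, dim)
    to_v: nn.Linear,
    group_norm: Optional[nn.GroupNorm] = None,
) -> torch.Tensor:
    """Sketch of PAG's perturbed path: identity attention, so no mask is needed."""
    if group_norm is not None:
        # group norm runs over the channel dimension, hence the transposes
        hidden_states_ptb = group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)
    # No query/key projections and no scaled_dot_product_attention call:
    # each token attends only to itself, so only the value projection remains
    # and an attention mask would never be applied.
    return to_v(hidden_states_ptb)


# Toy usage: shapes are preserved, (batch, seq_len, dim) in and out.
x = torch.randn(2, 16, 64)
out = identity_attention_ptb(x, to_v=nn.Linear(64, 64), group_norm=nn.GroupNorm(8, 64))
assert out.shape == x.shape
```

Because this branch never calls `scaled_dot_product_attention`, dropping the mask preparation in both hunks only removes unreachable work; the masked original attention path earlier in each `__call__` is unchanged.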