Skip to content

Commit a3a62c5

Browse files
authored
Improve MHA einsum (#781)
Efficiency update for einsum as mentioned in #772
1 parent 670f7a4 commit a3a62c5

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -519,7 +519,7 @@
519519
" scores = torch.einsum(\"bhnd,bhmd->bhnm\", Q, K) / (self.head_dim ** 0.5)\n",
520520
"\n",
521521
" # Apply mask\n",
522-
" mask = self.mask[:n, :n].unsqueeze(0).unsqueeze(1).expand(b, self.num_heads, n, n)\n",
522+
" mask = self.mask[:n, :n]\n",
523523
" scores = scores.masked_fill(mask.bool(), -torch.inf)\n",
524524
"\n",
525525
" # Softmax and dropout\n",

0 commit comments

Comments
 (0)