Skip to content

Commit 8879e72

Browse files
Make MultiHeadAttention use masks from query and value tensors (#7951)
BUG
* Use describeWithFlags
* Separate tests out of for loop
* Make MHA use masks from query and value tensors
* Fix lint
* Refactor mask computation into a separate function
1 parent f44e224 commit 8879e72

File tree

6 files changed

+231
-165
lines changed

6 files changed

+231
-165
lines changed

tfjs-core/src/tensor.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,8 @@ export class Tensor<R extends Rank = Rank> implements TensorInfo {
275275
kept = false;
276276
/** The id of the scope this tensor is being tracked in. */
277277
scopeId: number;
278+
/** The keras mask that some keras layers attach to the tensor */
279+
kerasMask?: Tensor;
278280

279281
/**
280282
* Number of elements to skip in each dimension when indexing. See
@@ -442,6 +444,9 @@ export class Tensor<R extends Rank = Rank> implements TensorInfo {
442444
if (this.isDisposed) {
443445
return;
444446
}
447+
if (this.kerasMask) {
448+
this.kerasMask.dispose();
449+
}
445450
trackerFn().disposeTensor(this);
446451
this.isDisposedInternal = true;
447452
}

tfjs-layers/src/base_callbacks.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -488,7 +488,8 @@ export function standardizeCallbacks(
488488
}
489489
// Convert custom callback configs to custom callback objects.
490490
const callbackConfigs =
491-
generic_utils.toList(callbacks) as CustomCallbackArgs[];
491+
generic_utils.toList<BaseCallback | CustomCallbackArgs>(
492+
callbacks) as CustomCallbackArgs[];
492493
return callbackConfigs.map(
493494
callbackConfig => new CustomCallback(callbackConfig, yieldEvery));
494495
}

tfjs-layers/src/engine/topology.ts

Lines changed: 67 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -751,19 +751,19 @@ export abstract class Layer extends serialization.Serializable {
751751
*/
752752
protected assertInputCompatibility(inputs: Tensor|Tensor[]|SymbolicTensor|
753753
SymbolicTensor[]): void {
754-
inputs = generic_utils.toList(inputs);
754+
const inputsList = generic_utils.toList(inputs);
755755
if (this.inputSpec == null || this.inputSpec.length === 0) {
756756
return;
757757
}
758758
const inputSpec = generic_utils.toList(this.inputSpec);
759-
if (inputs.length !== inputSpec.length) {
759+
if (inputsList.length !== inputSpec.length) {
760760
throw new ValueError(
761761
`Layer ${this.name} expects ${inputSpec.length} inputs, ` +
762-
`but it received ${inputs.length} input tensors. ` +
762+
`but it received ${inputsList.length} input tensors. ` +
763763
`Input received: ${inputs}`);
764764
}
765-
for (let inputIndex = 0; inputIndex < inputs.length; inputIndex++) {
766-
const x = inputs[inputIndex];
765+
for (let inputIndex = 0; inputIndex < inputsList.length; inputIndex++) {
766+
const x = inputsList[inputIndex];
767767
const spec: InputSpec = inputSpec[inputIndex];
768768
if (spec == null) {
769769
continue;
@@ -954,20 +954,8 @@ export abstract class Layer extends serialization.Serializable {
954954
// Ensure inputs are all the same type.
955955
const inputsList = generic_utils.toList(inputs);
956956

957-
let allAreSymbolic = true;
958-
for (const input of inputsList) {
959-
if (!(input instanceof SymbolicTensor)) {
960-
allAreSymbolic = false;
961-
break;
962-
}
963-
}
964-
let noneAreSymbolic = true;
965-
for (const input of inputsList) {
966-
if (input instanceof SymbolicTensor) {
967-
noneAreSymbolic = false;
968-
break;
969-
}
970-
}
957+
const allAreSymbolic = checkAllSymbolic(inputs);
958+
const noneAreSymbolic = checkNoneSymbolic(inputs);
971959

972960
if (allAreSymbolic === noneAreSymbolic) {
973961
throw new ValueError(
@@ -1017,8 +1005,13 @@ export abstract class Layer extends serialization.Serializable {
10171005

10181006
// Actually call the layer, collecting output(s), mask(s), and shape(s).
10191007
if (noneAreSymbolic) {
1020-
let output = this.call(inputs as Tensor | Tensor[], kwargs);
1021-
// TODO(michaelterry): Compute the outputMask
1008+
let output = this.call(inputs, kwargs);
1009+
1010+
// Apply masks to the output tensors if the layer supports it.
1011+
if (this.supportsMasking) {
1012+
// TODO(mattsoulanille): pass the input tensors' masks to computeMask
1013+
this.setMaskMetadata(inputs, output);
1014+
}
10221015

10231016
// If the layer returns tensors from its inputs, unmodified,
10241017
// we copy them to avoid loss of tensor metadata.
@@ -1073,8 +1066,7 @@ export abstract class Layer extends serialization.Serializable {
10731066
If the input tensor(s) had no previous history,
10741067
this does nothing.
10751068
*/
1076-
this.addInboundNode(
1077-
inputs as SymbolicTensor | SymbolicTensor[], output, null, null,
1069+
this.addInboundNode(inputs, output, null, null,
10781070
inputShape, outputShape, kwargs);
10791071
this._refCount++;
10801072

@@ -1395,6 +1387,32 @@ export abstract class Layer extends serialization.Serializable {
13951387
return mask;
13961388
}
13971389

1390+
private setMaskMetadata(inputs: Tensor|Tensor[], outputs: Tensor|Tensor[],
1391+
previousMask?: Tensor|Tensor[]): void {
1392+
if (!this.supportsMasking) {
1393+
return;
1394+
}
1395+
1396+
const outputMasks = this.computeMask(inputs, previousMask);
1397+
if (outputs instanceof Array && outputMasks instanceof Array) {
1398+
if (outputs.length !== outputMasks.length) {
1399+
throw new Error(`${this.name} outputs ${outputs.length} tensors `
1400+
+ `but ${outputMasks.length} masks for those tensors`);
1401+
}
1402+
for (let i = 0; i < outputs.length; i++) {
1403+
outputs[i].kerasMask = outputMasks[i];
1404+
}
1405+
} else if (outputMasks instanceof Array) {
1406+
throw new Error(`${this.name} outputs a single tensor `
1407+
+ `but ${outputMasks.length} masks`);
1408+
} else if (outputs instanceof Array) {
1409+
throw new Error(`${this.name} outputs ${outputs.length} tensors `
1410+
+ `but only one mask`);
1411+
} else {
1412+
outputs.kerasMask = outputMasks;
1413+
}
1414+
}
1415+
13981416
/**
13991417
* Internal method to create an inbound node for the layer.
14001418
*
@@ -1642,3 +1660,29 @@ export function getSourceInputs(
16421660
}
16431661
}
16441662
}
1663+
1664+
type MaybeSymbolic = SymbolicTensor | Tensor;
1665+
1666+
function checkAllSymbolic(tensors: MaybeSymbolic | MaybeSymbolic[]
1667+
): tensors is SymbolicTensor | SymbolicTensor[] {
1668+
let allAreSymbolic = true;
1669+
for (const tensor of generic_utils.toList(tensors)) {
1670+
if (!(tensor instanceof SymbolicTensor)) {
1671+
allAreSymbolic = false;
1672+
break;
1673+
}
1674+
}
1675+
return allAreSymbolic;
1676+
}
1677+
1678+
function checkNoneSymbolic(tensors: MaybeSymbolic | MaybeSymbolic[]
1679+
): tensors is Tensor | Tensor[] {
1680+
let noneAreSymbolic = true;
1681+
for (const tensor of generic_utils.toList(tensors)) {
1682+
if (tensor instanceof SymbolicTensor) {
1683+
noneAreSymbolic = false;
1684+
break;
1685+
}
1686+
}
1687+
return noneAreSymbolic;
1688+
}

tfjs-layers/src/layers/nlp/multihead_attention.ts

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
*/
2121

2222
/* Original source: keras/layers/attention/multi_head_attention.py */
23-
import { Tensor, einsum, linalg, mul, ones, serialization, tidy } from '@tensorflow/tfjs-core';
23+
import { Tensor, einsum, linalg, logicalAnd, mul, ones, serialization, tidy } from '@tensorflow/tfjs-core';
2424
// tslint:disable-next-line: no-imports-from-dist
2525
import { arraysEqual } from '@tensorflow/tfjs-core/dist/util_base';
2626

@@ -813,12 +813,20 @@ export class MultiHeadAttention extends Layer {
813813
return tidy(() => {
814814
let autoMask: Tensor;
815815

816+
const queryMask = query.kerasMask;
817+
const valueMask = value.kerasMask;
818+
if (queryMask != null) {
819+
autoMask = queryMask.expandDims(2); // Shape is [B, T, 1]
820+
}
821+
if (valueMask != null) {
822+
const mask = valueMask.expandDims(1); // Shape is [B, 1, S]
823+
autoMask = autoMask ? logicalAnd(autoMask, mask) : mask;
824+
}
816825
if (useCausalMask) {
817826
// the shape of the causal mask is [1, T, S]
818827
const mask = this.computeCausalMask(query, value);
819-
autoMask = mask;
828+
autoMask = autoMask ? logicalAnd(autoMask, mask) : mask;
820829
}
821-
822830
if (autoMask != null) {
823831
// Merge attentionMask & automatic mask, to shape [B, T, S]
824832
attentionMask = attentionMask ?

0 commit comments

Comments (0)