Skip to content

Commit 111048c

Browse files
Refactor mask computation into a separate function
1 parent bdf6ed1 commit 111048c

File tree

2 files changed

+58
-38
lines changed

2 files changed

+58
-38
lines changed

tfjs-layers/src/engine/topology.ts

Lines changed: 57 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -954,20 +954,8 @@ export abstract class Layer extends serialization.Serializable {
954954
// Ensure inputs are all the same type.
955955
const inputsList = generic_utils.toList(inputs);
956956

957-
let allAreSymbolic = true;
958-
for (const input of inputsList) {
959-
if (!(input instanceof SymbolicTensor)) {
960-
allAreSymbolic = false;
961-
break;
962-
}
963-
}
964-
let noneAreSymbolic = true;
965-
for (const input of inputsList) {
966-
if (input instanceof SymbolicTensor) {
967-
noneAreSymbolic = false;
968-
break;
969-
}
970-
}
957+
const allAreSymbolic = checkAllSymbolic(inputs);
958+
const noneAreSymbolic = checkNoneSymbolic(inputs);
971959

972960
if (allAreSymbolic === noneAreSymbolic) {
973961
throw new ValueError(
@@ -1017,31 +1005,12 @@ export abstract class Layer extends serialization.Serializable {
10171005

10181006
// Actually call the layer, collecting output(s), mask(s), and shape(s).
10191007
if (noneAreSymbolic) {
1020-
let output = this.call(inputs as Tensor | Tensor[], kwargs);
1008+
let output = this.call(inputs, kwargs);
10211009

10221010
// Apply masks to the output tensors if the layer supports it.
10231011
if (this.supportsMasking) {
10241012
// TODO(mattsoulanille): pass the input tensors' masks to computeMask
1025-
const outputMask = this.computeMask(inputs as Tensor | Tensor[]);
1026-
if (output instanceof Array && outputMask instanceof Array) {
1027-
if (output.length !== outputMask.length) {
1028-
throw new Error(`${this.name} output ${output.length} tensors `
1029-
+ `but ${outputMask.length} masks for those tensors`);
1030-
}
1031-
for (let i = 0; i < output.length; i++) {
1032-
output[i].kerasMask = outputMask[i];
1033-
}
1034-
} else if (outputMask instanceof Array) {
1035-
throw new Error(`{this.name} output a single tensor `
1036-
+ `but ${outputMask.length} masks`);
1037-
} else if (output instanceof Array) {
1038-
for (const out of output) {
1039-
out.kerasMask = outputMask.clone();
1040-
}
1041-
outputMask.dispose(); // Only keep the clones to avoid leaking
1042-
} else {
1043-
output.kerasMask = outputMask;
1044-
}
1013+
this.setMaskMetadata(inputs, output);
10451014
}
10461015

10471016
// If the layer returns tensors from its inputs, unmodified,
@@ -1097,8 +1066,7 @@ export abstract class Layer extends serialization.Serializable {
10971066
If the input tensor(s) had no previous history,
10981067
this does nothing.
10991068
*/
1100-
this.addInboundNode(
1101-
inputs as SymbolicTensor | SymbolicTensor[], output, null, null,
1069+
this.addInboundNode(inputs, output, null, null,
11021070
inputShape, outputShape, kwargs);
11031071
this._refCount++;
11041072

@@ -1419,6 +1387,32 @@ export abstract class Layer extends serialization.Serializable {
14191387
return mask;
14201388
}
14211389

1390+
private setMaskMetadata(inputs: Tensor|Tensor[], outputs: Tensor|Tensor[],
1391+
previous_mask?: Tensor|Tensor[]): void {
1392+
if (!this.supportsMasking) {
1393+
return;
1394+
}
1395+
1396+
const outputMasks = this.computeMask(inputs, previous_mask);
1397+
if (outputs instanceof Array && outputMasks instanceof Array) {
1398+
if (outputs.length !== outputMasks.length) {
1399+
throw new Error(`${this.name} outputs ${outputs.length} tensors `
1400+
+ `but ${outputMasks.length} masks for those tensors`);
1401+
}
1402+
for (let i = 0; i < outputs.length; i++) {
1403+
outputs[i].kerasMask = outputMasks[i];
1404+
}
1405+
} else if (outputMasks instanceof Array) {
1406+
throw new Error(`{this.name} outputs a single tensor `
1407+
+ `but ${outputMasks.length} masks`);
1408+
} else if (outputs instanceof Array) {
1409+
throw new Error(`{this.name} outputs ${outputs.length} tensors `
1410+
+ `but only one mask`);
1411+
} else {
1412+
outputs.kerasMask = outputMasks;
1413+
}
1414+
}
1415+
14221416
/**
14231417
* Internal method to create an inbound node for the layer.
14241418
*
@@ -1666,3 +1660,29 @@ export function getSourceInputs(
16661660
}
16671661
}
16681662
}
1663+
1664+
type MaybeSymbolic = SymbolicTensor | Tensor;
1665+
1666+
function checkAllSymbolic(tensors: MaybeSymbolic | MaybeSymbolic[]
1667+
): tensors is SymbolicTensor | SymbolicTensor[] {
1668+
let allAreSymbolic = true;
1669+
for (const tensor of generic_utils.toList(tensors)) {
1670+
if (!(tensor instanceof SymbolicTensor)) {
1671+
allAreSymbolic = false;
1672+
break;
1673+
}
1674+
}
1675+
return allAreSymbolic;
1676+
}
1677+
1678+
function checkNoneSymbolic(tensors: MaybeSymbolic | MaybeSymbolic[]
1679+
): tensors is Tensor | Tensor[] {
1680+
let noneAreSymbolic = true;
1681+
for (const tensor of generic_utils.toList(tensors)) {
1682+
if (tensor instanceof SymbolicTensor) {
1683+
noneAreSymbolic = false;
1684+
break;
1685+
}
1686+
}
1687+
return noneAreSymbolic;
1688+
}

tfjs-layers/src/utils/generic_utils.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ export function singletonOrArray<T>(xs: T[]): T|T[] {
7676
* @param x target object to be normalized.
7777
*/
7878
// tslint:disable-next-line:no-any
79-
export function toList(x: any): any[] {
79+
export function toList<T>(x: T|T[]): T[] {
8080
if (Array.isArray(x)) {
8181
return x;
8282
}

0 commit comments

Comments
 (0)