@@ -751,19 +751,19 @@ export abstract class Layer extends serialization.Serializable {
751
751
*/
752
752
protected assertInputCompatibility ( inputs : Tensor | Tensor [ ] | SymbolicTensor |
753
753
SymbolicTensor [ ] ) : void {
754
- inputs = generic_utils . toList ( inputs ) ;
754
+ const inputsList = generic_utils . toList ( inputs ) ;
755
755
if ( this . inputSpec == null || this . inputSpec . length === 0 ) {
756
756
return ;
757
757
}
758
758
const inputSpec = generic_utils . toList ( this . inputSpec ) ;
759
- if ( inputs . length !== inputSpec . length ) {
759
+ if ( inputsList . length !== inputSpec . length ) {
760
760
throw new ValueError (
761
761
`Layer ${ this . name } expects ${ inputSpec . length } inputs, ` +
762
- `but it received ${ inputs . length } input tensors. ` +
762
+ `but it received ${ inputsList . length } input tensors. ` +
763
763
`Input received: ${ inputs } ` ) ;
764
764
}
765
- for ( let inputIndex = 0 ; inputIndex < inputs . length ; inputIndex ++ ) {
766
- const x = inputs [ inputIndex ] ;
765
+ for ( let inputIndex = 0 ; inputIndex < inputsList . length ; inputIndex ++ ) {
766
+ const x = inputsList [ inputIndex ] ;
767
767
const spec : InputSpec = inputSpec [ inputIndex ] ;
768
768
if ( spec == null ) {
769
769
continue ;
@@ -954,20 +954,8 @@ export abstract class Layer extends serialization.Serializable {
954
954
// Ensure inputs are all the same type.
955
955
const inputsList = generic_utils . toList ( inputs ) ;
956
956
957
- let allAreSymbolic = true ;
958
- for ( const input of inputsList ) {
959
- if ( ! ( input instanceof SymbolicTensor ) ) {
960
- allAreSymbolic = false ;
961
- break ;
962
- }
963
- }
964
- let noneAreSymbolic = true ;
965
- for ( const input of inputsList ) {
966
- if ( input instanceof SymbolicTensor ) {
967
- noneAreSymbolic = false ;
968
- break ;
969
- }
970
- }
957
+ const allAreSymbolic = checkAllSymbolic ( inputs ) ;
958
+ const noneAreSymbolic = checkNoneSymbolic ( inputs ) ;
971
959
972
960
if ( allAreSymbolic === noneAreSymbolic ) {
973
961
throw new ValueError (
@@ -1017,8 +1005,13 @@ export abstract class Layer extends serialization.Serializable {
1017
1005
1018
1006
// Actually call the layer, collecting output(s), mask(s), and shape(s).
1019
1007
if ( noneAreSymbolic ) {
1020
- let output = this . call ( inputs as Tensor | Tensor [ ] , kwargs ) ;
1021
- // TODO(michaelterry): Compute the outputMask
1008
+ let output = this . call ( inputs , kwargs ) ;
1009
+
1010
+ // Apply masks to the output tensors if the layer supports it.
1011
+ if ( this . supportsMasking ) {
1012
+ // TODO(mattsoulanille): pass the input tensors' masks to computeMask
1013
+ this . setMaskMetadata ( inputs , output ) ;
1014
+ }
1022
1015
1023
1016
// If the layer returns tensors from its inputs, unmodified,
1024
1017
// we copy them to avoid loss of tensor metadata.
@@ -1073,8 +1066,7 @@ export abstract class Layer extends serialization.Serializable {
1073
1066
If the input tensor(s) had no previous history,
1074
1067
this does nothing.
1075
1068
*/
1076
- this . addInboundNode (
1077
- inputs as SymbolicTensor | SymbolicTensor [ ] , output , null , null ,
1069
+ this . addInboundNode ( inputs , output , null , null ,
1078
1070
inputShape , outputShape , kwargs ) ;
1079
1071
this . _refCount ++ ;
1080
1072
@@ -1395,6 +1387,32 @@ export abstract class Layer extends serialization.Serializable {
1395
1387
return mask ;
1396
1388
}
1397
1389
1390
+ private setMaskMetadata ( inputs : Tensor | Tensor [ ] , outputs : Tensor | Tensor [ ] ,
1391
+ previousMask ?: Tensor | Tensor [ ] ) : void {
1392
+ if ( ! this . supportsMasking ) {
1393
+ return ;
1394
+ }
1395
+
1396
+ const outputMasks = this . computeMask ( inputs , previousMask ) ;
1397
+ if ( outputs instanceof Array && outputMasks instanceof Array ) {
1398
+ if ( outputs . length !== outputMasks . length ) {
1399
+ throw new Error ( `${ this . name } outputs ${ outputs . length } tensors `
1400
+ + `but ${ outputMasks . length } masks for those tensors` ) ;
1401
+ }
1402
+ for ( let i = 0 ; i < outputs . length ; i ++ ) {
1403
+ outputs [ i ] . kerasMask = outputMasks [ i ] ;
1404
+ }
1405
+ } else if ( outputMasks instanceof Array ) {
1406
+ throw new Error ( `{this.name} outputs a single tensor `
1407
+ + `but ${ outputMasks . length } masks` ) ;
1408
+ } else if ( outputs instanceof Array ) {
1409
+ throw new Error ( `{this.name} outputs ${ outputs . length } tensors `
1410
+ + `but only one mask` ) ;
1411
+ } else {
1412
+ outputs . kerasMask = outputMasks ;
1413
+ }
1414
+ }
1415
+
1398
1416
/**
1399
1417
* Internal method to create an inbound node for the layer.
1400
1418
*
@@ -1642,3 +1660,29 @@ export function getSourceInputs(
1642
1660
}
1643
1661
}
1644
1662
}
1663
+
1664
+ type MaybeSymbolic = SymbolicTensor | Tensor ;
1665
+
1666
+ function checkAllSymbolic ( tensors : MaybeSymbolic | MaybeSymbolic [ ]
1667
+ ) : tensors is SymbolicTensor | SymbolicTensor [ ] {
1668
+ let allAreSymbolic = true ;
1669
+ for ( const tensor of generic_utils . toList ( tensors ) ) {
1670
+ if ( ! ( tensor instanceof SymbolicTensor ) ) {
1671
+ allAreSymbolic = false ;
1672
+ break ;
1673
+ }
1674
+ }
1675
+ return allAreSymbolic ;
1676
+ }
1677
+
1678
+ function checkNoneSymbolic ( tensors : MaybeSymbolic | MaybeSymbolic [ ]
1679
+ ) : tensors is Tensor | Tensor [ ] {
1680
+ let noneAreSymbolic = true ;
1681
+ for ( const tensor of generic_utils . toList ( tensors ) ) {
1682
+ if ( tensor instanceof SymbolicTensor ) {
1683
+ noneAreSymbolic = false ;
1684
+ break ;
1685
+ }
1686
+ }
1687
+ return noneAreSymbolic ;
1688
+ }
0 commit comments