@@ -954,20 +954,8 @@ export abstract class Layer extends serialization.Serializable {
954
954
// Ensure inputs are all the same type.
955
955
const inputsList = generic_utils . toList ( inputs ) ;
956
956
957
- let allAreSymbolic = true ;
958
- for ( const input of inputsList ) {
959
- if ( ! ( input instanceof SymbolicTensor ) ) {
960
- allAreSymbolic = false ;
961
- break ;
962
- }
963
- }
964
- let noneAreSymbolic = true ;
965
- for ( const input of inputsList ) {
966
- if ( input instanceof SymbolicTensor ) {
967
- noneAreSymbolic = false ;
968
- break ;
969
- }
970
- }
957
+ const allAreSymbolic = checkAllSymbolic ( inputs ) ;
958
+ const noneAreSymbolic = checkNoneSymbolic ( inputs ) ;
971
959
972
960
if ( allAreSymbolic === noneAreSymbolic ) {
973
961
throw new ValueError (
@@ -1017,31 +1005,12 @@ export abstract class Layer extends serialization.Serializable {
1017
1005
1018
1006
// Actually call the layer, collecting output(s), mask(s), and shape(s).
1019
1007
if ( noneAreSymbolic ) {
1020
- let output = this . call ( inputs as Tensor | Tensor [ ] , kwargs ) ;
1008
+ let output = this . call ( inputs , kwargs ) ;
1021
1009
1022
1010
// Apply masks to the output tensors if the layer supports it.
1023
1011
if ( this . supportsMasking ) {
1024
1012
// TODO(mattsoulanille): pass the input tensors' masks to computeMask
1025
- const outputMask = this . computeMask ( inputs as Tensor | Tensor [ ] ) ;
1026
- if ( output instanceof Array && outputMask instanceof Array ) {
1027
- if ( output . length !== outputMask . length ) {
1028
- throw new Error ( `${ this . name } output ${ output . length } tensors `
1029
- + `but ${ outputMask . length } masks for those tensors` ) ;
1030
- }
1031
- for ( let i = 0 ; i < output . length ; i ++ ) {
1032
- output [ i ] . kerasMask = outputMask [ i ] ;
1033
- }
1034
- } else if ( outputMask instanceof Array ) {
1035
- throw new Error ( `{this.name} output a single tensor `
1036
- + `but ${ outputMask . length } masks` ) ;
1037
- } else if ( output instanceof Array ) {
1038
- for ( const out of output ) {
1039
- out . kerasMask = outputMask . clone ( ) ;
1040
- }
1041
- outputMask . dispose ( ) ; // Only keep the clones to avoid leaking
1042
- } else {
1043
- output . kerasMask = outputMask ;
1044
- }
1013
+ this . setMaskMetadata ( inputs , output ) ;
1045
1014
}
1046
1015
1047
1016
// If the layer returns tensors from its inputs, unmodified,
@@ -1097,8 +1066,7 @@ export abstract class Layer extends serialization.Serializable {
1097
1066
If the input tensor(s) had no previous history,
1098
1067
this does nothing.
1099
1068
*/
1100
- this . addInboundNode (
1101
- inputs as SymbolicTensor | SymbolicTensor [ ] , output , null , null ,
1069
+ this . addInboundNode ( inputs , output , null , null ,
1102
1070
inputShape , outputShape , kwargs ) ;
1103
1071
this . _refCount ++ ;
1104
1072
@@ -1419,6 +1387,32 @@ export abstract class Layer extends serialization.Serializable {
1419
1387
return mask ;
1420
1388
}
1421
1389
1390
+ private setMaskMetadata ( inputs : Tensor | Tensor [ ] , outputs : Tensor | Tensor [ ] ,
1391
+ previous_mask ?: Tensor | Tensor [ ] ) : void {
1392
+ if ( ! this . supportsMasking ) {
1393
+ return ;
1394
+ }
1395
+
1396
+ const outputMasks = this . computeMask ( inputs , previous_mask ) ;
1397
+ if ( outputs instanceof Array && outputMasks instanceof Array ) {
1398
+ if ( outputs . length !== outputMasks . length ) {
1399
+ throw new Error ( `${ this . name } outputs ${ outputs . length } tensors `
1400
+ + `but ${ outputMasks . length } masks for those tensors` ) ;
1401
+ }
1402
+ for ( let i = 0 ; i < outputs . length ; i ++ ) {
1403
+ outputs [ i ] . kerasMask = outputMasks [ i ] ;
1404
+ }
1405
+ } else if ( outputMasks instanceof Array ) {
1406
+ throw new Error ( `{this.name} outputs a single tensor `
1407
+ + `but ${ outputMasks . length } masks` ) ;
1408
+ } else if ( outputs instanceof Array ) {
1409
+ throw new Error ( `{this.name} outputs ${ outputs . length } tensors `
1410
+ + `but only one mask` ) ;
1411
+ } else {
1412
+ outputs . kerasMask = outputMasks ;
1413
+ }
1414
+ }
1415
+
1422
1416
/**
1423
1417
* Internal method to create an inbound node for the layer.
1424
1418
*
@@ -1666,3 +1660,29 @@ export function getSourceInputs(
1666
1660
}
1667
1661
}
1668
1662
}
1663
+
1664
+ type MaybeSymbolic = SymbolicTensor | Tensor ;
1665
+
1666
+ function checkAllSymbolic ( tensors : MaybeSymbolic | MaybeSymbolic [ ]
1667
+ ) : tensors is SymbolicTensor | SymbolicTensor [ ] {
1668
+ let allAreSymbolic = true ;
1669
+ for ( const tensor of generic_utils . toList ( tensors ) ) {
1670
+ if ( ! ( tensor instanceof SymbolicTensor ) ) {
1671
+ allAreSymbolic = false ;
1672
+ break ;
1673
+ }
1674
+ }
1675
+ return allAreSymbolic ;
1676
+ }
1677
+
1678
+ function checkNoneSymbolic ( tensors : MaybeSymbolic | MaybeSymbolic [ ]
1679
+ ) : tensors is Tensor | Tensor [ ] {
1680
+ let noneAreSymbolic = true ;
1681
+ for ( const tensor of generic_utils . toList ( tensors ) ) {
1682
+ if ( tensor instanceof SymbolicTensor ) {
1683
+ noneAreSymbolic = false ;
1684
+ break ;
1685
+ }
1686
+ }
1687
+ return noneAreSymbolic ;
1688
+ }
0 commit comments