@@ -47,6 +47,12 @@ def experimental_distribute_dataset(self, dataset, options=None):
4747        return  dataset 
4848
4949
50+ class  JaxDummyStrategy (DummyStrategy ):
51+     @property  
52+     def  num_replicas_in_sync (self ):
53+         return  len (jax .devices ("tpu" ))
54+ 
55+ 
5056class  DistributedEmbeddingTest (testing .TestCase , parameterized .TestCase ):
5157    def  setUp (self ):
5258        super ().setUp ()
@@ -80,9 +86,15 @@ def setUp(self):
8086            )
8187            print ("### num_replicas" , self ._strategy .num_replicas_in_sync )
8288            self .addCleanup (tf .tpu .experimental .shutdown_tpu_system , resolver )
89+         elif  keras .backend .backend () ==  "jax"  and  self .on_tpu :
90+             self ._strategy  =  JaxDummyStrategy ()
8391        else :
8492            self ._strategy  =  DummyStrategy ()
8593
94+         self .batch_size  =  (
95+             BATCH_SIZE_PER_CORE  *  self ._strategy .num_replicas_in_sync 
96+         )
97+ 
8698    def  run_with_strategy (self , fn , * args , jit_compile = False ):
8799        """Wrapper for running a function under a strategy.""" 
88100
@@ -120,31 +132,31 @@ def get_embedding_config(self, input_type, placement):
120132        feature_group ["feature1" ] =  config .FeatureConfig (
121133            name = "feature1" ,
122134            table = feature1_table ,
123-             input_shape = (BATCH_SIZE_PER_CORE , sequence_length ),
124-             output_shape = (BATCH_SIZE_PER_CORE , FEATURE1_EMBEDDING_OUTPUT_DIM ),
135+             input_shape = (self . batch_size , sequence_length ),
136+             output_shape = (self . batch_size , FEATURE1_EMBEDDING_OUTPUT_DIM ),
125137        )
126138        feature_group ["feature2" ] =  config .FeatureConfig (
127139            name = "feature2" ,
128140            table = feature2_table ,
129-             input_shape = (BATCH_SIZE_PER_CORE , sequence_length ),
130-             output_shape = (BATCH_SIZE_PER_CORE , FEATURE2_EMBEDDING_OUTPUT_DIM ),
141+             input_shape = (self . batch_size , sequence_length ),
142+             output_shape = (self . batch_size , FEATURE2_EMBEDDING_OUTPUT_DIM ),
131143        )
132144        return  {"feature_group" : feature_group }
133145
134146    def  create_inputs_weights_and_labels (
135-         self , batch_size ,  input_type , feature_configs , backend = None 
147+         self , input_type , feature_configs , backend = None 
136148    ):
137149        backend  =  backend  or  keras .backend .backend ()
138150
139151        if  input_type  ==  "dense" :
140152
141153            def  create_tensor (feature_config , op ):
142-                 sequence_length  =  feature_config .input_shape [- 1 ]
143-                 return  op ((batch_size , sequence_length ))
154+                 return  op (feature_config .input_shape )
144155
145156        elif  input_type  ==  "ragged" :
146157
147158            def  create_tensor (feature_config , op ):
159+                 batch_size  =  feature_config .input_shape [0 ]
148160                sequence_length  =  feature_config .input_shape [- 1 ]
149161                row_lengths  =  [
150162                    1  +  (i  %  sequence_length ) for  i  in  range (batch_size )
@@ -157,6 +169,7 @@ def create_tensor(feature_config, op):
157169        elif  input_type  ==  "sparse"  and  backend  ==  "tensorflow" :
158170
159171            def  create_tensor (feature_config , op ):
172+                 batch_size  =  feature_config .input_shape [0 ]
160173                sequence_length  =  feature_config .input_shape [- 1 ]
161174                indices  =  [[i , i  %  sequence_length ] for  i  in  range (batch_size )]
162175                return  tf .sparse .reorder (
@@ -170,6 +183,7 @@ def create_tensor(feature_config, op):
170183        elif  input_type  ==  "sparse"  and  backend  ==  "jax" :
171184
172185            def  create_tensor (feature_config , op ):
186+                 batch_size  =  feature_config .input_shape [0 ]
173187                sequence_length  =  feature_config .input_shape [- 1 ]
174188                indices  =  [[i , i  %  sequence_length ] for  i  in  range (batch_size )]
175189                return  jax_sparse .BCOO (
@@ -197,9 +211,7 @@ def create_tensor(feature_config, op):
197211            feature_configs ,
198212        )
199213        labels  =  keras .tree .map_structure (
200-             lambda  fc : np .ones (
201-                 (batch_size ,) +  fc .output_shape [1 :], dtype = np .float32 
202-             ),
214+             lambda  fc : np .ones (fc .output_shape , dtype = np .float32 ),
203215            feature_configs ,
204216        )
205217        return  inputs , weights , labels 
@@ -228,10 +240,9 @@ def test_basics(self, input_type, placement):
228240        ):
229241            self .skipTest ("Ragged and sparse are not compilable on TPU." )
230242
231-         batch_size  =  self ._strategy .num_replicas_in_sync  *  BATCH_SIZE_PER_CORE 
232243        feature_configs  =  self .get_embedding_config (input_type , placement )
233244        inputs , weights , _  =  self .create_inputs_weights_and_labels (
234-             batch_size ,  input_type , feature_configs 
245+             input_type , feature_configs 
235246        )
236247
237248        if  placement  ==  "sparsecore"  and  not  self .on_tpu :
@@ -264,11 +275,11 @@ def test_basics(self, input_type, placement):
264275
265276        self .assertEqual (
266277            res ["feature_group" ]["feature1" ].shape ,
267-             (batch_size , FEATURE1_EMBEDDING_OUTPUT_DIM ),
278+             (self . batch_size , FEATURE1_EMBEDDING_OUTPUT_DIM ),
268279        )
269280        self .assertEqual (
270281            res ["feature_group" ]["feature2" ].shape ,
271-             (batch_size , FEATURE2_EMBEDDING_OUTPUT_DIM ),
282+             (self . batch_size , FEATURE2_EMBEDDING_OUTPUT_DIM ),
272283        )
273284
274285    @parameterized .named_parameters ( 
@@ -289,16 +300,15 @@ def test_model_fit(self, input_type, use_weights):
289300                    f"{ input_type } { keras .backend .backend ()}  
290301                )
291302
292-         batch_size  =  self ._strategy .num_replicas_in_sync  *  BATCH_SIZE_PER_CORE 
293303        feature_configs  =  self .get_embedding_config (input_type , self .placement )
294304        train_inputs , train_weights , train_labels  =  (
295305            self .create_inputs_weights_and_labels (
296-                 batch_size ,  input_type , feature_configs , backend = "tensorflow" 
306+                 input_type , feature_configs , backend = "tensorflow" 
297307            )
298308        )
299309        test_inputs , test_weights , test_labels  =  (
300310            self .create_inputs_weights_and_labels (
301-                 batch_size ,  input_type , feature_configs , backend = "tensorflow" 
311+                 input_type , feature_configs , backend = "tensorflow" 
302312            )
303313        )
304314
@@ -482,12 +492,11 @@ def test_correctness(
482492        feature_config  =  config .FeatureConfig (
483493            name = "feature" ,
484494            table = table ,
485-             input_shape = (BATCH_SIZE_PER_CORE , sequence_length ),
486-             output_shape = (BATCH_SIZE_PER_CORE , EMBEDDING_OUTPUT_DIM ),
495+             input_shape = (self . batch_size , sequence_length ),
496+             output_shape = (self . batch_size , EMBEDDING_OUTPUT_DIM ),
487497        )
488498
489-         batch_size  =  self ._strategy .num_replicas_in_sync  *  BATCH_SIZE_PER_CORE 
490-         num_repeats  =  batch_size  //  2 
499+         num_repeats  =  self .batch_size  //  2 
491500        if  input_type  ==  "dense"  and  input_rank  ==  1 :
492501            inputs  =  keras .ops .convert_to_tensor ([2 , 3 ] *  num_repeats )
493502            weights  =  keras .ops .convert_to_tensor ([1.0 , 2.0 ] *  num_repeats )
@@ -512,14 +521,14 @@ def test_correctness(
512521                    tf .SparseTensor (
513522                        indices ,
514523                        [1 , 2 , 3 , 4 , 5 ] *  num_repeats ,
515-                         dense_shape = (batch_size , 4 ),
524+                         dense_shape = (self . batch_size , 4 ),
516525                    )
517526                )
518527                weights  =  tf .sparse .reorder (
519528                    tf .SparseTensor (
520529                        indices ,
521530                        [1.0 , 1.0 , 2.0 , 3.0 , 4.0 ] *  num_repeats ,
522-                         dense_shape = (batch_size , 4 ),
531+                         dense_shape = (self . batch_size , 4 ),
523532                    )
524533                )
525534            elif  keras .backend .backend () ==  "jax" :
@@ -528,15 +537,15 @@ def test_correctness(
528537                        jnp .asarray ([1 , 2 , 3 , 4 , 5 ] *  num_repeats ),
529538                        jnp .asarray (indices ),
530539                    ),
531-                     shape = (batch_size , 4 ),
540+                     shape = (self . batch_size , 4 ),
532541                    unique_indices = True ,
533542                )
534543                weights  =  jax_sparse .BCOO (
535544                    (
536545                        jnp .asarray ([1.0 , 1.0 , 2.0 , 3.0 , 4.0 ] *  num_repeats ),
537546                        jnp .asarray (indices ),
538547                    ),
539-                     shape = (batch_size , 4 ),
548+                     shape = (self . batch_size , 4 ),
540549                    unique_indices = True ,
541550                )
542551            else :
@@ -600,7 +609,7 @@ def test_correctness(
600609                layer .__call__ , inputs , weights , jit_compile = jit_compile 
601610            )
602611
603-         self .assertEqual (res .shape , (batch_size , EMBEDDING_OUTPUT_DIM ))
612+         self .assertEqual (res .shape , (self . batch_size , EMBEDDING_OUTPUT_DIM ))
604613
605614        tables  =  layer .get_embedding_tables ()
606615        emb  =  tables ["table" ]
@@ -644,26 +653,25 @@ def test_shared_table(self):
644653            "feature1" : config .FeatureConfig (
645654                name = "feature1" ,
646655                table = table1 ,
647-                 input_shape = (BATCH_SIZE_PER_CORE , 1 ),
648-                 output_shape = (BATCH_SIZE_PER_CORE , EMBEDDING_OUTPUT_DIM ),
656+                 input_shape = (self . batch_size , 1 ),
657+                 output_shape = (self . batch_size , EMBEDDING_OUTPUT_DIM ),
649658            ),
650659            "feature2" : config .FeatureConfig (
651660                name = "feature2" ,
652661                table = table1 ,
653-                 input_shape = (BATCH_SIZE_PER_CORE , 1 ),
654-                 output_shape = (BATCH_SIZE_PER_CORE , EMBEDDING_OUTPUT_DIM ),
662+                 input_shape = (self . batch_size , 1 ),
663+                 output_shape = (self . batch_size , EMBEDDING_OUTPUT_DIM ),
655664            ),
656665            "feature3" : config .FeatureConfig (
657666                name = "feature3" ,
658667                table = table1 ,
659-                 input_shape = (BATCH_SIZE_PER_CORE , 1 ),
660-                 output_shape = (BATCH_SIZE_PER_CORE , EMBEDDING_OUTPUT_DIM ),
668+                 input_shape = (self . batch_size , 1 ),
669+                 output_shape = (self . batch_size , EMBEDDING_OUTPUT_DIM ),
661670            ),
662671        }
663672
664-         batch_size  =  self ._strategy .num_replicas_in_sync  *  BATCH_SIZE_PER_CORE 
665673        inputs , _ , _  =  self .create_inputs_weights_and_labels (
666-             batch_size ,  "dense" , embedding_config 
674+             "dense" , embedding_config 
667675        )
668676
669677        with  self ._strategy .scope ():
@@ -676,13 +684,13 @@ def test_shared_table(self):
676684            self .assertLen (layer .trainable_variables , 1 )
677685
678686        self .assertEqual (
679-             res ["feature1" ].shape , (batch_size , EMBEDDING_OUTPUT_DIM )
687+             res ["feature1" ].shape , (self . batch_size , EMBEDDING_OUTPUT_DIM )
680688        )
681689        self .assertEqual (
682-             res ["feature2" ].shape , (batch_size , EMBEDDING_OUTPUT_DIM )
690+             res ["feature2" ].shape , (self . batch_size , EMBEDDING_OUTPUT_DIM )
683691        )
684692        self .assertEqual (
685-             res ["feature3" ].shape , (batch_size , EMBEDDING_OUTPUT_DIM )
693+             res ["feature3" ].shape , (self . batch_size , EMBEDDING_OUTPUT_DIM )
686694        )
687695
688696    def  test_mixed_placement (self ):
@@ -719,26 +727,25 @@ def test_mixed_placement(self):
719727            "feature1" : config .FeatureConfig (
720728                name = "feature1" ,
721729                table = table1 ,
722-                 input_shape = (BATCH_SIZE_PER_CORE , 1 ),
723-                 output_shape = (BATCH_SIZE_PER_CORE , embedding_output_dim1 ),
730+                 input_shape = (self . batch_size , 1 ),
731+                 output_shape = (self . batch_size , embedding_output_dim1 ),
724732            ),
725733            "feature2" : config .FeatureConfig (
726734                name = "feature2" ,
727735                table = table2 ,
728-                 input_shape = (BATCH_SIZE_PER_CORE , 1 ),
729-                 output_shape = (BATCH_SIZE_PER_CORE , embedding_output_dim2 ),
736+                 input_shape = (self . batch_size , 1 ),
737+                 output_shape = (self . batch_size , embedding_output_dim2 ),
730738            ),
731739            "feature3" : config .FeatureConfig (
732740                name = "feature3" ,
733741                table = table3 ,
734-                 input_shape = (BATCH_SIZE_PER_CORE , 1 ),
735-                 output_shape = (BATCH_SIZE_PER_CORE , embedding_output_dim3 ),
742+                 input_shape = (self . batch_size , 1 ),
743+                 output_shape = (self . batch_size , embedding_output_dim3 ),
736744            ),
737745        }
738746
739-         batch_size  =  self ._strategy .num_replicas_in_sync  *  BATCH_SIZE_PER_CORE 
740747        inputs , _ , _  =  self .create_inputs_weights_and_labels (
741-             batch_size ,  "dense" , embedding_config 
748+             "dense" , embedding_config 
742749        )
743750
744751        with  self ._strategy .scope ():
@@ -747,20 +754,19 @@ def test_mixed_placement(self):
747754        res  =  self .run_with_strategy (layer .__call__ , inputs )
748755
749756        self .assertEqual (
750-             res ["feature1" ].shape , (batch_size , embedding_output_dim1 )
757+             res ["feature1" ].shape , (self . batch_size , embedding_output_dim1 )
751758        )
752759        self .assertEqual (
753-             res ["feature2" ].shape , (batch_size , embedding_output_dim2 )
760+             res ["feature2" ].shape , (self . batch_size , embedding_output_dim2 )
754761        )
755762        self .assertEqual (
756-             res ["feature3" ].shape , (batch_size , embedding_output_dim3 )
763+             res ["feature3" ].shape , (self . batch_size , embedding_output_dim3 )
757764        )
758765
759766    def  test_save_load_model (self ):
760-         batch_size  =  self ._strategy .num_replicas_in_sync  *  BATCH_SIZE_PER_CORE 
761767        feature_configs  =  self .get_embedding_config ("dense" , self .placement )
762768        inputs , _ , _  =  self .create_inputs_weights_and_labels (
763-             batch_size ,  "dense" , feature_configs 
769+             "dense" , feature_configs 
764770        )
765771
766772        keras_inputs  =  keras .tree .map_structure (
0 commit comments