     table_stacking as jte_table_stacking,
 )
 from jax_tpu_embedding.sparsecore.utils import utils as jte_utils
-from keras.src import backend
 
 from keras_rs.src import types
 from keras_rs.src.layers.embedding import base_distributed_embedding
@@ -247,23 +246,6 @@ def _create_sparsecore_distribution(
         )
         return sparsecore_distribution, sparsecore_layout
 
-    def _create_cpu_distribution(
-        self, cpu_axis_name: str = "cpu"
-    ) -> tuple[
-        keras.distribution.ModelParallel, keras.distribution.TensorLayout
-    ]:
-        """Share a variable across all CPU processes."""
-        cpu_devices = jax.devices("cpu")
-        device_mesh = keras.distribution.DeviceMesh(
-            (len(cpu_devices),), [cpu_axis_name], cpu_devices
-        )
-        replicated_layout = keras.distribution.TensorLayout([], device_mesh)
-        layout_map = keras.distribution.LayoutMap(device_mesh=device_mesh)
-        cpu_distribution = keras.distribution.ModelParallel(
-            layout_map=layout_map
-        )
-        return cpu_distribution, replicated_layout
-
     def _add_sparsecore_weight(
         self,
         name: str,
@@ -405,11 +387,6 @@ def sparsecore_build(
         self._sparsecore_layout = sparsecore_layout
         self._sparsecore_distribution = sparsecore_distribution
 
-        # Distribution for CPU operations.
-        cpu_distribution, cpu_layout = self._create_cpu_distribution()
-        self._cpu_distribution = cpu_distribution
-        self._cpu_layout = cpu_layout
-
         mesh = sparsecore_distribution.device_mesh.backend_mesh
         global_device_count = mesh.devices.size
         num_sc_per_device = jte_utils.num_sparsecores_per_device(
@@ -466,10 +443,6 @@ def sparsecore_build(
         # Collect all stacked tables.
         table_specs = embedding_utils.get_table_specs(feature_specs)
         table_stacks = embedding_utils.get_table_stacks(table_specs)
-        stacked_table_specs = {
-            stack_name: stack[0].stacked_table_spec
-            for stack_name, stack in table_stacks.items()
-        }
 
         # Create variables for all stacked tables and slot variables.
         with sparsecore_distribution.scope():
@@ -502,50 +475,6 @@ def sparsecore_build(
             )
             self._iterations.overwrite_with_gradient = True
 
-        with cpu_distribution.scope():
-            # Create variables to track static buffer size and max IDs for each
-            # table during preprocessing.  These variables are shared across all
-            # processes on CPU.  We don't add these via `add_weight` because we
-            # can't have them passed to the training function.
-            replicated_zeros_initializer = ShardedInitializer(
-                "zeros", cpu_layout
-            )
-
-            with backend.name_scope(self.name, caller=self):
-                self._preprocessing_buffer_size = {
-                    table_name: backend.Variable(
-                        initializer=replicated_zeros_initializer,
-                        shape=(),
-                        dtype=backend.standardize_dtype("int32"),
-                        trainable=False,
-                        name=table_name + ":preprocessing:buffer_size",
-                    )
-                    for table_name in stacked_table_specs.keys()
-                }
-                self._preprocessing_max_unique_ids_per_partition = {
-                    table_name: backend.Variable(
-                        shape=(),
-                        name=table_name
-                        + ":preprocessing:max_unique_ids_per_partition",
-                        initializer=replicated_zeros_initializer,
-                        dtype=backend.standardize_dtype("int32"),
-                        trainable=False,
-                    )
-                    for table_name in stacked_table_specs.keys()
-                }
-
-                self._preprocessing_max_ids_per_partition = {
-                    table_name: backend.Variable(
-                        shape=(),
-                        name=table_name
-                        + ":preprocessing:max_ids_per_partition",
-                        initializer=replicated_zeros_initializer,
-                        dtype=backend.standardize_dtype("int32"),
-                        trainable=False,
-                    )
-                    for table_name in stacked_table_specs.keys()
-                }
-
         self._config = jte_embedding_lookup.EmbeddingLookupConfiguration(
             feature_specs,
             mesh=mesh,
@@ -660,125 +589,60 @@ def _sparsecore_preprocess(
             mesh.devices.item(0)
         )
 
-        # Get current buffer size/max_ids.
-        previous_max_ids_per_partition = keras.tree.map_structure(
-            lambda max_ids_per_partition: max_ids_per_partition.value.item(),
-            self._preprocessing_max_ids_per_partition,
-        )
-        previous_max_unique_ids_per_partition = keras.tree.map_structure(
-            lambda max_unique_ids_per_partition: (
-                max_unique_ids_per_partition.value.item()
-            ),
-            self._preprocessing_max_unique_ids_per_partition,
-        )
-        previous_buffer_size = keras.tree.map_structure(
-            lambda buffer_size: buffer_size.value.item(),
-            self._preprocessing_buffer_size,
-        )
-
         preprocessed, stats = embedding_utils.stack_and_shard_samples(
             self._config.feature_specs,
             samples,
             local_device_count,
             global_device_count,
             num_sc_per_device,
-            static_buffer_size=previous_buffer_size,
         )
 
-        # Extract max unique IDs and buffer sizes.
-        # We need to replicate this value across all local CPU devices.
         if training:
+            # Synchronize input statistics across all devices and update the
+            # underlying stacked tables specs in the feature specs.
+            prev_stats = embedding_utils.get_stacked_table_stats(
+                self._config.feature_specs
+            )
+
+            # Take the maximum with existing stats.
+            stats = keras.tree.map_structure(max, prev_stats, stats)
+
+            # Flatten the stats so we can more efficiently transfer them
+            # between hosts.  We use jax.tree because we will later need to
+            # unflatten.
+            flat_stats, stats_treedef = jax.tree.flatten(stats)
+
+            # In the case of multiple local CPU devices per host, we need to
+            # replicate the stats to placate JAX collectives.
             num_local_cpu_devices = jax.local_device_count("cpu")
-            local_max_ids_per_partition = {
-                table_name: np.repeat(
-                    # Maximum across all partitions and previous max.
-                    np.maximum(
-                        np.max(elems),
-                        previous_max_ids_per_partition[table_name],
-                    ),
-                    num_local_cpu_devices,
-                )
-                for table_name, elems in stats.max_ids_per_partition.items()
-            }
-            local_max_unique_ids_per_partition = {
-                name: np.repeat(
-                    # Maximum across all partitions and previous max.
-                    np.maximum(
-                        np.max(elems),
-                        previous_max_unique_ids_per_partition[name],
-                    ),
-                    num_local_cpu_devices,
-                )
-                for name, elems in stats.max_unique_ids_per_partition.items()
-            }
-            local_buffer_size = {
-                table_name: np.repeat(
-                    np.maximum(
-                        np.max(
-                            # Round values up to the next multiple of 8.
-                            # Currently using this as a proxy for the actual
-                            # required buffer size.
-                            ((elems + 7) // 8) * 8
-                        )
-                        * global_device_count
-                        * num_sc_per_device
-                        * local_device_count
-                        * num_sc_per_device,
-                        previous_buffer_size[table_name],
-                    ),
-                    num_local_cpu_devices,
-                )
-                for table_name, elems in stats.max_ids_per_partition.items()
-            }
+            tiled_stats = np.tile(
+                np.array(flat_stats, dtype=np.int32), (num_local_cpu_devices, 1)
+            )
 
             # Aggregate variables across all processes/devices.
             max_across_cpus = jax.pmap(
                 lambda x: jax.lax.pmax(  # type: ignore[no-untyped-call]
                     x, "all_cpus"
                 ),
                 axis_name="all_cpus",
-                devices=self._cpu_layout.device_mesh.backend_mesh.devices,
-            )
-            new_max_ids_per_partition = max_across_cpus(
-                local_max_ids_per_partition
-            )
-            new_max_unique_ids_per_partition = max_across_cpus(
-                local_max_unique_ids_per_partition
+                backend="cpu",
             )
-            new_buffer_size = max_across_cpus(local_buffer_size)
-
-            # Assign new preprocessing parameters.
-            with self._cpu_distribution.scope():
-                # For each process, all max ids/buffer sizes are replicated
-                # across all local devices.  Take the value from the first
-                # device.
-                keras.tree.map_structure(
-                    lambda var, values: var.assign(values[0]),
-                    self._preprocessing_max_ids_per_partition,
-                    new_max_ids_per_partition,
-                )
-                keras.tree.map_structure(
-                    lambda var, values: var.assign(values[0]),
-                    self._preprocessing_max_unique_ids_per_partition,
-                    new_max_unique_ids_per_partition,
-                )
-                keras.tree.map_structure(
-                    lambda var, values: var.assign(values[0]),
-                    self._preprocessing_buffer_size,
-                    new_buffer_size,
-                )
-                # Update parameters in the underlying feature specs.
-                int_max_ids_per_partition = keras.tree.map_structure(
-                    lambda varray: varray.item(), new_max_ids_per_partition
-                )
-                int_max_unique_ids_per_partition = keras.tree.map_structure(
-                    lambda varray: varray.item(),
-                    new_max_unique_ids_per_partition,
+            flat_stats = max_across_cpus(tiled_stats)[0].tolist()
+            stats = jax.tree.unflatten(stats_treedef, flat_stats)
+
+            # Update configuration and repeat preprocessing if stats changed.
+            if stats != prev_stats:
+                embedding_utils.update_stacked_table_stats(
+                    self._config.feature_specs, stats
                 )
-                embedding_utils.update_stacked_table_specs(
+
+                # Re-execute preprocessing with consistent input statistics.
+                preprocessed, _ = embedding_utils.stack_and_shard_samples(
                     self._config.feature_specs,
-                    int_max_ids_per_partition,
-                    int_max_unique_ids_per_partition,
+                    samples,
+                    local_device_count,
+                    global_device_count,
+                    num_sc_per_device,
                 )
 
         return {"inputs": preprocessed}
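The added preprocessing path replaces the per-table CPU variables with a simpler round-trip: merge the locally observed statistics into the existing stacked-table stats, flatten them, take an element-wise maximum across every process's CPU devices with `jax.pmap`/`jax.lax.pmax`, and re-run `stack_and_shard_samples` only when the merged values changed. The sketch below illustrates just that synchronization step, assuming a flat `dict` of per-table integer statistics; the helper name `sync_stats_across_hosts` and the sample values are illustrative and not part of the keras-rs or jax_tpu_embedding APIs.

```python
import jax
import numpy as np


def sync_stats_across_hosts(stats: dict[str, int]) -> dict[str, int]:
    """Element-wise max of per-table stats across all processes' CPU devices."""
    # Flatten the nested structure so it can travel as a single int32 array.
    flat_stats, treedef = jax.tree.flatten(stats)

    # Replicate across local CPU devices so pmap sees one row per device.
    num_local_cpu_devices = jax.local_device_count("cpu")
    tiled = np.tile(
        np.array(flat_stats, dtype=np.int32), (num_local_cpu_devices, 1)
    )

    # pmax over the named axis aggregates across all participating CPU
    # devices (all hosts, when jax.distributed is initialized).
    max_across_cpus = jax.pmap(
        lambda x: jax.lax.pmax(x, "all_cpus"),
        axis_name="all_cpus",
        backend="cpu",
    )
    merged = max_across_cpus(tiled)[0].tolist()
    return jax.tree.unflatten(treedef, merged)


# Each host contributes its locally observed maxima and gets the global
# maximum back, so every process preprocesses with identical limits.
local_stats = {"item_table": 96, "user_table": 64}  # illustrative values
print(sync_stats_across_hosts(local_stats))
```

Because every process ends up with the same global maxima, the `stats != prev_stats` check in the diff means the second preprocessing pass is only paid while the observed statistics are still growing; once they stabilize, a single `stack_and_shard_samples` call per step suffices.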