@@ -32,7 +32,25 @@ class DistributedEmbedding(keras.layers.Layer):
3232
3333 ---
3434
35- ## Configuration
35+ `DistributedEmbedding` is a layer optimized for TPU chips with SparseCore
36+ and can dramatically improve the speed of embedding lookups and embedding
37+ training. It works by combining multiple lookups into one invocation, and by
38+ sharding the embedding tables across the available chips. Note that one will
39+ only see performance benefits for embedding tables that are large enough to
40+ to require sharding because they don't fit on a single chip. More details
41+ are provided in the "Placement" section below.
42+
43+ On other hardware, GPUs, CPUs and TPUs without SparseCore,
44+ `DistributedEmbedding` provides the same API without any specific
45+ acceleration. No particular distribution scheme is applied besides the one
46+ set via `keras.distribution.set_distribution`.
47+
48+ `DistributedEmbedding` embeds sequences of inputs and reduces them to a
49+ single embedding by applying a configurable combiner function.
50+
51+ ### Configuration
52+
53+ #### Features and tables
3654
3755 A `DistributedEmbedding` embedding layer is configured via a set of
3856 `keras_rs.layers.FeatureConfig` objects, which themselves refer to
@@ -50,11 +68,13 @@ class DistributedEmbedding(keras.layers.Layer):
5068 name="table1",
5169 vocabulary_size=TABLE1_VOCABULARY_SIZE,
5270 embedding_dim=TABLE1_EMBEDDING_SIZE,
71+ placement="auto",
5372 )
5473 table2 = keras_rs.layers.TableConfig(
5574 name="table2",
5675 vocabulary_size=TABLE2_VOCABULARY_SIZE,
5776 embedding_dim=TABLE2_EMBEDDING_SIZE,
77+ placement="auto",
5878 )
5979
6080 feature1 = keras_rs.layers.FeatureConfig(
@@ -78,22 +98,141 @@ class DistributedEmbedding(keras.layers.Layer):
7898 embedding = keras_rs.layers.DistributedEmbedding(feature_configs)
7999 ```
80100
81- ## Optimizers
101+ #### Optimizers
82102
83103 Each embedding table within `DistributedEmbedding` uses its own optimizer
84104 for training, which is independent from the optimizer set on the model via
85105 `model.compile()`.
86106
87107 Note that not all optimizers are supported. Currently, the following are
88- always supported (i.e. on all backends and accelerators) :
108+ supported on all backends and accelerators:
89109
90110 - `keras.optimizers.Adagrad`
91111 - `keras.optimizers.SGD`
92112
93- Additionally, not all parameters of the optimizers are supported (e.g. the
113+ The following are additionally available when using the TensorFlow backend:
114+
115+ - `keras.optimizers.Adam`
116+ - `keras.optimizers.Ftrl`
117+
118+ Also, not all parameters of the optimizers are supported (e.g. the
94119 `nesterov` option of `SGD`). An error is raised when an unsupported
95120 optimizer or an unsupported optimizer parameter is used.
96121
122+ #### Placement
123+
124+ Each embedding table within `DistributedEmbedding` can be either placed on
125+ the SparseCore chip or the default device placement for the accelerator
126+ (e.g. HBM of the Tensor Cores on TPU). This is controlled by the `placement`
127+ attribute of `keras_rs.layers.TableConfig`.
128+
129+ - A placement of `"sparsecore"` indicates that the table should be placed on
130+ the SparseCore chips. An error is raised if this option is selected and
131+ there are no SparseCore chips.
132+ - A placement of `"default_device"` indicates that the table should not be
133+ placed on SparseCore, even if available. Instead the table is placed on
134+ the device where the model normally goes, i.e. the HBM on TPUs and GPUs.
135+ In this case, if applicable, the table is distributed using the scheme set
136+ via `keras.distribution.set_distribution`. On GPUs, CPUs and TPUs without
137+ SparseCore, this is the only placement available, and is the one selected
138+ by `"auto"`.
139+ - A placement of `"auto"` indicates to use `"sparsecore"` if available, and
140+ `"default_device"` otherwise. This is the default when not specified.
141+
142+ To optimize performance on TPU:
143+
144+ - Tables that are so large that they need to be sharded should use the
145+ `"sparsecore"` placement.
146+ - Tables that are small enough should use `"default_device"` and should
147+ typically be replicated across TPUs by using the
148+ `keras.distribution.DataParallel` distribution option.
149+
150+ ### Usage with TensorFlow on TPU with SpareCore
151+
152+ #### Inputs
153+
154+ In addition to `tf.Tensor`, `DistributedEmbedding` accepts `tf.RaggedTensor`
155+ and `tf.SparseTensor` as inputs for the embedding lookups. Ragged tensors
156+ must be ragged in the dimension with index 1. Note that if weights are
157+ passed, each weight tensor must be of the same class as the inputs for that
158+ particular feature and use the exact same ragged row lenghts for ragged
159+ tensors, and the same indices for sparse tensors. All the output of
160+ `DistributedEmbedding` are dense tensors.
161+
162+ #### Setup
163+
164+ To use `DistributedEmbedding` on TPUs with TensorFlow, one must use a
165+ `tf.distribute.TPUStrategy`. The `DistributedEmbedding` layer must be
166+ created under the `TPUStrategy`.
167+
168+ ```python
169+ resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local")
170+ topology = tf.tpu.experimental.initialize_tpu_system(resolver)
171+ device_assignment = tf.tpu.experimental.DeviceAssignment.build(
172+ topology, num_replicas=resolver.get_tpu_system_metadata().num_cores
173+ )
174+ strategy = tf.distribute.TPUStrategy(
175+ resolver, experimental_device_assignment=device_assignment
176+ )
177+
178+ with strategy.scope():
179+ embedding = keras_rs.layers.DistributedEmbedding(feature_configs)
180+ ```
181+
182+ #### Usage in a Keras model
183+
184+ To use Keras' `model.fit()`, one must compile the model under the
185+ `TPUStrategy`. Then, `model.fit()`, `model.evaluate()` or `model.predict()`
186+ can be called directly. The Keras model takes care of running the model
187+ using the strategy and also automatically distributes the dataset.
188+
189+ ```python
190+ with strategy.scope():
191+ embedding = keras_rs.layers.DistributedEmbedding(feature_configs)
192+ model = create_model(embedding)
193+ model.compile(loss=keras.losses.MeanSquaredError(), optimizer="adam")
194+
195+ model.fit(dataset, epochs=10)
196+ ```
197+
198+ #### Direct invocation
199+
200+ `DistributedEmbedding` must be invoked via a `strategy.run` call nested in a
201+ `tf.function`.
202+
203+ ```python
204+ @tf.function
205+ def embedding_wrapper(tf_fn_inputs, tf_fn_weights=None):
206+ def strategy_fn(st_fn_inputs, st_fn_weights):
207+ return embedding(st_fn_inputs, st_fn_weights)
208+
209+ return strategy.run(strategy_fn, args=(tf_fn_inputs, tf_fn_weights)))
210+
211+ embedding_wrapper(my_inputs, my_weights)
212+ ```
213+
214+ When using a dataset, the dataset must be distributed. The iterator can then
215+ be passed to the `tf.function` that uses `strategy.run`.
216+
217+ ```python
218+ dataset = strategy.experimental_distribute_dataset(dataset)
219+
220+ @tf.function
221+ def run_loop(iterator):
222+ def step(data):
223+ (inputs, weights), labels = data
224+ with tf.GradientTape() as tape:
225+ result = embedding(inputs, weights)
226+ loss = keras.losses.mean_squared_error(labels, result)
227+ tape.gradient(loss, embedding.trainable_variables)
228+ return result
229+
230+ for _ in tf.range(4):
231+ result = strategy.run(step, args=(next(iterator),))
232+
233+ run_loop(iter(dataset))
234+ ```
235+
97236 Args:
98237 feature_configs: A nested structure of `keras_rs.layers.FeatureConfig`.
99238 table_stacking: The table stacking to use. `None` means no table
@@ -282,7 +421,7 @@ def preprocess(
282421 to feeding data into a model.
283422
284423 An example usage might look like:
285- ```
424+ ```python
286425 # Create the embedding layer.
287426 embedding_layer = DistributedEmbedding(feature_configs)
288427
0 commit comments