
Conversation

@hertschuh (Collaborator)

No description provided.

@hertschuh force-pushed the embedding_tf branch 3 times, most recently from 4939b13 to 2bd6f5a on May 21, 2025 at 20:14
@cantonios (Collaborator) left a comment:

Approving, with a couple of minor nits inline.

In addition to `tf.Tensor`, `DistributedEmbedding` accepts `tf.RaggedTensor`
and `tf.SparseTensor` as inputs for the embedding lookups. Ragged tensors
must be ragged in dimension 1. Note that if weights are passed, each weight
@cantonios (Collaborator):

Is dimension 1 the dimension with index 0? Does TF actually support any other kind of raggedness?

@hertschuh (Collaborator, Author) replied:

I meant the dimension with index 1. TF supports raggedness in any dimension, and in fact multiple ragged dimensions: a tensor can be twice ragged.

https://www.tensorflow.org/api_docs/python/tf/RaggedTensor#multiple_ragged_dimensions

https://www.tensorflow.org/api_docs/python/tf/RaggedTensor#attributes (see ragged_rank)
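For illustration (not part of the thread), a minimal sketch of both the dimension-1 case and a twice-ragged tensor:

```python
import tensorflow as tf

# Ragged in the dimension with index 1: each row has a variable
# number of ids, which is the case the docstring describes.
ids = tf.ragged.constant([[1, 2, 3], [4], [5, 6]])
print(ids.ragged_rank)  # 1

# TF also supports multiple ragged dimensions: wrapping a ragged
# tensor in another ragged dimension yields ragged_rank 2.
twice_ragged = tf.RaggedTensor.from_row_lengths(
    values=tf.ragged.constant([[1, 2], [3], [4, 5, 6], [7]]),
    row_lengths=[3, 1],
)
print(twice_ragged.ragged_rank)  # 2
```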

In addition to `tf.Tensor`, `DistributedEmbedding` accepts `tf.RaggedTensor`
and `tf.SparseTensor` as inputs for the embedding lookups. Ragged tensors
must be ragged in dimension 1. Note that if weights are passed, each weight
tensor must be of the same type as the inputs for that particular feature
@cantonios (Collaborator):

"same type" might be a bit confusing because we expect the datatype for indices to be integers, but weights to be floats.

Same class?
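A minimal sketch of the intended distinction (example values are illustrative): ids and weights share the class and row structure, but not the dtype.

```python
import tensorflow as tf

# Same class (tf.RaggedTensor) and same row structure for the ids and
# weights of a given feature, but different dtypes: integer indices,
# float weights.
ids = tf.ragged.constant([[3, 7], [1], [4, 4, 2]], dtype=tf.int64)
weights = tf.ragged.constant(
    [[0.5, 0.5], [1.0], [0.2, 0.3, 0.5]], dtype=tf.float32
)
assert type(ids) is type(weights)  # both tf.RaggedTensor
```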

@hertschuh merged commit 31d881b into keras-team:main on May 21, 2025
5 checks passed
@hertschuh added a commit that referenced this pull request on Jun 3, 2025:
* Ignore shard_map attr error in mypy. (#97)

* Added TF specific documentation to `DistributedEmbedding`. (#94)

* Fix symbolic calls for `EmbedReduce`. (#98)

`EmbedReduce` was inheriting the behavior from `Embedding` and not correctly applying the reduction.

* Move `DistributedEmbedding` declaration to its own file. (#99)

Having it in `__init__.py` doesn't play nice with pytype.

* Remove dependency on `tree` and use `keras.tree`. (#100)

Keras can already depend on either `dmtree` or `optree` and use whichever is best or available on the current platform.
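For context, a minimal sketch of `keras.tree` in use (example values are illustrative):

```python
import keras

# keras.tree exposes the same structure utilities as `dmtree`/`optree`
# and dispatches to whichever package is available.
nested = {"ids": [1, 2], "weights": [0.5, 0.25]}
doubled = keras.tree.map_structure(lambda x: x * 2, nested)
# {"ids": [2, 4], "weights": [1.0, 0.5]}
```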

* Only enable JAX on linux_x86_64. (#101)

* Add `out_sharding` argument to `WrappedKerasInitializer`. (#102)

This is for forward compatibility: the latest versions of JAX introduce the
`out_sharding` argument.

* Use Python 3.10 style type annotations. (#104)

Now that we require Python 3.10, we can use the shorter annotation style, which should improve the readability of the documentation.
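For example (illustrative, not taken from the diff):

```python
# Pre-3.10 style:
from typing import Dict, List, Optional

def lookup(ids: Optional[List[int]]) -> Dict[str, float]: ...

# Python 3.10 style: shorter and easier to read.
def lookup(ids: list[int] | None) -> dict[str, float]: ...
```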

* Do not bundle test utils in wheel. (#105)

* Update version number to 0.2.1 (#106)

As 0.2.0 was just released.

* Fix invalid escape sequence in unit test. (#108)

* Replace leftover `unflatten_as` with `pack_sequence_as`. (#109)

This instance was missed as it is only run on TPU.
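A minimal sketch of the replacement call (structure and values are illustrative):

```python
import keras

# Pack a flat list of leaves back into a nested structure.
structure = {"context": None, "history": (None, None)}
flat = [1, 2, 3]
packed = keras.tree.pack_sequence_as(structure, flat)
# {"context": 1, "history": (2, 3)}
```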

* Make the declaration of `Nested` compatible with pytype. (#110)

pytype doesn't support `|` between forward declarations written as strings.
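A hedged sketch of the shape of the fix; the actual `Nested` definition in the source may differ:

```python
from typing import Union

# pytype rejects the `|` spelling when an operand is a string forward
# reference (e.g. `int | list["Nested"]`); the explicit Union form works.
Nested = Union[int, list["Nested"], dict[str, "Nested"]]
```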

* Add ragged support for default_device placement on JAX. (#107)

Requires calling `preprocess`.

Internally, we currently convert ragged inputs to dense before passing
to the embedding call(...) function.
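A hedged sketch of that densification step (the actual implementation lives in the JAX backend; names here are illustrative):

```python
import numpy as np

# Pad variable-length rows of ids to a rectangle so the embedding
# call(...) sees a dense array; 0 acts as the padding id here.
ragged_ids = [[1, 2, 3], [4], [5, 6]]
max_len = max(len(row) for row in ragged_ids)
dense_ids = np.zeros((len(ragged_ids), max_len), dtype=np.int32)
for i, row in enumerate(ragged_ids):
    dense_ids[i, : len(row)] = row
# [[1 2 3], [4 0 0], [5 6 0]]
```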

* Add documentation for using DistributedEmbedding with JAX. (#111)

* `api_gen` now excludes backend specific code. (#103)

This:
- Allows development (`api_gen` / git presubmit hooks) without all backends and backend-specific dependencies installed and working. For instance, jax_tpu_embedding currently doesn't import on macOS Sequoia; this allows running `api_gen` regardless.
- Makes sure we don't accidentally create and honor exports that are backend-specific.

* Enable preprocess calls with symbolic input tensors. (#113)

This allows us to more-easily create functional models via:
```python
preprocessed_inputs = distributed_embedding.preprocess(symbolic_inputs, symbolic_weights)
outputs = distributed_embedding(preprocessed_inputs)
model = keras.Model(inputs=preprocessed_inputs, outputs=outputs)
```

* Check for jax_tpu_embedding on JAX backend. (#114)

This is to allow users to potentially run Keras RS _without_ the
dependency.

If a user doesn't have `jax-tpu-embedding` installed but is on
`linux_x86_64` with a sparsecore-capable TPU available, and they try to
use `auto` or `sparsecore` placement with distributed embedding, an
error is raised informing them to install the dependency.
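A hedged sketch of such a guard (the function name and message are hypothetical):

```python
try:
    import jax_tpu_embedding  # noqa: F401
    _HAS_JAX_TPU_EMBEDDING = True
except ImportError:
    _HAS_JAX_TPU_EMBEDDING = False

def _check_sparsecore_dependency(placement: str) -> None:
    # Hypothetical guard: only `auto` and `sparsecore` placements need
    # the jax-tpu-embedding package.
    if placement in ("auto", "sparsecore") and not _HAS_JAX_TPU_EMBEDDING:
        raise ImportError(
            f"Placement '{placement}' requires the `jax-tpu-embedding` "
            "package; install it with `pip install jax-tpu-embedding`."
        )
```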

---------

Co-authored-by: C. Antonio Sánchez <[email protected]>
Co-authored-by: hertschuh <[email protected]>