Skip to content

Commit d7b9d5a

Browse files
authored
Fix docstring for DistributedEmbedding. (#118)
The JAX usage section was broken.
1 parent 89bb54f commit d7b9d5a

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

keras_rs/src/layers/embedding/base_distributed_embedding.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,7 @@ def step(data):
297297
result = strategy.run(step, args=(next(iterator),))
298298
299299
run_loop(iter(dataset))
300+
```
300301
301302
### Usage with JAX on TPU with SpareCore
302303

0 commit comments

Comments
 (0)