Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 39 additions & 10 deletions keras/src/layers/preprocessing/string_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
from keras.src.utils import backend_utils
from keras.src.utils.module_utils import tensorflow as tf

if backend.backend() == "torch":
import torch


@keras_export("keras.layers.StringLookup")
class StringLookup(IndexLookup):
Expand Down Expand Up @@ -382,13 +385,39 @@ def get_config(self):
return {**base_config, **config}

def call(self, inputs):
if isinstance(inputs, (tf.Tensor, tf.RaggedTensor, tf.SparseTensor)):
tf_inputs = True
else:
tf_inputs = False
if not isinstance(inputs, (np.ndarray, list, tuple)):
inputs = tf.convert_to_tensor(backend.convert_to_numpy(inputs))
outputs = super().call(inputs)
if not tf_inputs:
outputs = backend_utils.convert_tf_tensor(outputs)
return outputs
is_torch_backend = backend.backend() == "torch"

# Handle input conversion
inputs_for_processing = inputs
was_tf_input = isinstance(
inputs, (tf.Tensor, tf.RaggedTensor, tf.SparseTensor)
)

if is_torch_backend and isinstance(inputs, torch.Tensor):
inputs_for_processing = tf.convert_to_tensor(
inputs.detach().cpu().numpy()
)
elif isinstance(inputs, (np.ndarray, list, tuple)):
inputs_for_processing = tf.convert_to_tensor(inputs)
elif not was_tf_input:
inputs_for_processing = tf.convert_to_tensor(
backend.convert_to_numpy(inputs)
)

output = super().call(inputs_for_processing)

# Handle torch backend output conversion
if is_torch_backend and isinstance(
inputs, (torch.Tensor, np.ndarray, list, tuple)
):
numpy_outputs = output.numpy()
if self.invert:
return [n.decode(self.encoding) for n in numpy_outputs]
else:
return torch.from_numpy(numpy_outputs)

# other backends
if not was_tf_input:
output = backend_utils.convert_tf_tensor(output)

return output
30 changes: 30 additions & 0 deletions keras/src/layers/preprocessing/string_lookup_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,33 @@ def test_tensor_as_vocab(self):
)
output = layer(data)
self.assertAllClose(output, np.array([[1, 3, 4], [4, 0, 2]]))

@pytest.mark.skipif(backend.backend() != "torch", reason="Only torch")
def test_torch_backend_compatibility(self):
import torch

# Forward lookup: String -> number
forward_lookup = layers.StringLookup(
vocabulary=["a", "b", "c"], oov_token="[OOV]"
)
input_data_str = ["a", "b", "[OOV]", "d"]
output_numeric = forward_lookup(input_data_str)

# assert instance of output is torch.Tensor
self.assertIsInstance(output_numeric, torch.Tensor)
expected_numeric = torch.tensor([1, 2, 0, 0])
self.assertAllClose(output_numeric.cpu(), expected_numeric)

oov = "[OOV]"
# Inverse lookup: Number -> string
inverse_lookup = layers.StringLookup(
vocabulary=["a", "b", "c"], oov_token=oov, invert=True
)
input_data_int = torch.tensor([1, 2, 0], dtype=torch.int64)
output_string = inverse_lookup(input_data_int)
# Assert that the output is a list
# See : https://docs.pytorch.org/text/stable/_modules/torchtext/vocab/vocab.html#Vocab.lookup_tokens
# The torch equivalent implementation of this returns a list of strings
self.assertIsInstance(output_string, list)
expected_string = ["a", "b", "[OOV]"]
self.assertEqual(output_string, expected_string)