Skip to content

Commit 13b5c46

Browse files
committed
Fix TypedNdArray issue
---- Co-authored by: @apaleyes
1 parent 8afd111 commit 13b5c46

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

tesseract_jax/tesseract_compat.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,19 @@ def unflatten_args(
5252
if remove_static_args:
5353
result = _prune_nones(result)
5454

55+
# Since jax 0.8, when tracing stuff without jit arrays are wrapped
56+
# by TypedNdArray (thin wrapper around a numpy array); this snippet converts them
57+
# back to ndarrays for downstream calculations.
58+
try:
59+
from jax._src.literals import TypedNdArray
60+
61+
result = jax.tree.map(
62+
lambda v: v.val if isinstance(v, TypedNdArray) else v, result
63+
)
64+
65+
except ImportError:
66+
pass
67+
5568
return result
5669

5770

0 commit comments

Comments
 (0)