We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 8afd111 commit 13b5c46Copy full SHA for 13b5c46
tesseract_jax/tesseract_compat.py
@@ -52,6 +52,19 @@ def unflatten_args(
52
if remove_static_args:
53
result = _prune_nones(result)
54
55
+ # Since jax 0.8, when tracing stuff without jit arrays are wrapped
56
+ # by TypedNdArray (thin wrapper around a numpy array); this snippet converts them
57
+ # back to ndarrays for downstream calculations.
58
+ try:
59
+ from jax._src.literals import TypedNdArray
60
+
61
+ result = jax.tree.map(
62
+ lambda v: v.val if isinstance(v, TypedNdArray) else v, result
63
+ )
64
65
+ except ImportError:
66
+ pass
67
68
return result
69
70
0 commit comments