
Incorrect indexing with jnp.int8 and Python scalar #31396

@dodgebc

Description

I expected these three lines to give the same output, with JAX handling the type conversions.

```python
import jax.numpy as jnp

numbers = jnp.arange(10000*10).reshape(10000, 10)

# Python int + int32 array
print(numbers[1000, jnp.array(0, dtype=jnp.int32)])
# int32 array + int8 array
print(numbers[jnp.array(1000, dtype=jnp.int32), jnp.array(0, dtype=jnp.int8)])
# Python int + int8 array
print(numbers[1000, jnp.array(0, dtype=jnp.int8)])
```

However, the output is:

```
10000
10000
0
```
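One possible explanation, offered as an assumption and not confirmed in this report: if JAX combines the index dtypes under its usual type-promotion rules, a Python int is weakly typed and adopts the dtype of the int8 array, so 1000 would have to live in int8 (range -128 to 127). An explicit int32 array, by contrast, is strongly typed and wins over int8, which would match the second call returning the correct value. The sketch below only checks those promotion results and the int8 range; it does not inspect JAX's indexing internals.

```python
import jax.numpy as jnp

# int8 can only hold -128..127, so the row index 1000 does not fit.
info = jnp.iinfo(jnp.int8)
print(info.min, info.max)  # -128 127

# A Python int is weakly typed: combined with int8 it promotes to int8.
print(jnp.result_type(1000, jnp.int8))

# A strongly typed int32 combined with int8 promotes to int32.
print(jnp.result_type(jnp.int32, jnp.int8))
```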

Is this expected to fail silently? I am also not sure a Python scalar is required to trigger it: the same problem seemed to appear in a larger piece of code where the first index was also a JAX array, but I haven't yet been able to build a small reproducer for that case.
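Until the behavior is clarified, one defensive workaround is to cast every index to a single integer dtype that is wide enough for all axis lengths before indexing. The sketch below assumes int32 is sufficient (true for this 10000 x 10 array); `index_2d` is a hypothetical helper, not a JAX API.

```python
import jax.numpy as jnp


def index_2d(arr, i, j, dtype=jnp.int32):
    """Hypothetical helper: cast both indices to one integer dtype
    (int32 by default) before indexing a 2-D array."""
    i = jnp.asarray(i, dtype=dtype)
    j = jnp.asarray(j, dtype=dtype)
    return arr[i, j]


numbers = jnp.arange(10000 * 10).reshape(10000, 10)

# The combination from the report: Python int row, int8 array column.
row, col = 1000, jnp.array(0, dtype=jnp.int8)
print(index_2d(numbers, row, col))  # 10000
```

Casting explicitly sidesteps weak-type promotion entirely, at the cost of choosing the index dtype by hand.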

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.6.0
jaxlib: 0.6.0
numpy: 2.3.0
python: 3.13.3 (main, Apr 9 2025, 04:03:52) [Clang 20.1.0 ]
device info: NVIDIA GeForce RTX 4050 Laptop GPU-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='bendodge', release='5.15.153.1-microsoft-standard-WSL2', version='#1 SMP Fri Mar 29 23:14:13 UTC 2024', machine='x86_64')

Metadata

Labels

bug (Something isn't working)
