|
20 | 20 | from pytensor.printing import Printer, pprint, set_precedence |
21 | 21 | from pytensor.scalar.basic import ScalarConstant |
22 | 22 | from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length |
23 | | -from pytensor.tensor.basic import alloc, get_underlying_scalar_constant_value |
| 23 | +from pytensor.tensor.basic import alloc, get_underlying_scalar_constant_value, nonzero |
24 | 24 | from pytensor.tensor.elemwise import DimShuffle |
25 | | -from pytensor.tensor.exceptions import ( |
26 | | - AdvancedIndexingError, |
27 | | - NotScalarConstantError, |
28 | | - ShapeError, |
29 | | -) |
| 25 | +from pytensor.tensor.exceptions import AdvancedIndexingError, NotScalarConstantError |
30 | 26 | from pytensor.tensor.math import clip |
31 | | -from pytensor.tensor.shape import Reshape, specify_broadcastable |
| 27 | +from pytensor.tensor.shape import Reshape, shape_i, specify_broadcastable |
32 | 28 | from pytensor.tensor.type import ( |
33 | 29 | TensorType, |
34 | 30 | bscalar, |
@@ -2584,26 +2580,47 @@ def R_op(self, inputs, eval_points): |
2584 | 2580 | return self.make_node(eval_points[0], *inputs[1:]).outputs |
2585 | 2581 |
|
2586 | 2582 | def infer_shape(self, fgraph, node, ishapes): |
2587 | | - indices = node.inputs[1:] |
2588 | | - index_shapes = list(ishapes[1:]) |
2589 | | - for i, idx in enumerate(indices): |
2590 | | - if ( |
| 2583 | + def is_bool_index(idx): |
| 2584 | + return ( |
2591 | 2585 | isinstance(idx, (np.bool_, bool)) |
2592 | 2586 | or getattr(idx, "dtype", None) == "bool" |
2593 | | - ): |
2594 | | - raise ShapeError( |
2595 | | - "Shape inference for boolean indices is not implemented" |
| 2587 | + ) |
| 2588 | + |
| 2589 | + indices = node.inputs[1:] |
| 2590 | + index_shapes = [] |
| 2591 | + for idx, ishape in zip(indices, ishapes[1:]): |
| 2592 | + # Mixed bool indexes are converted to nonzero entries |
| 2593 | + if is_bool_index(idx): |
| 2594 | + index_shapes.extend( |
| 2595 | + (shape_i(nz_dim, 0, fgraph=fgraph),) for nz_dim in nonzero(idx) |
2596 | 2596 | ) |
2597 | 2597 | # The `ishapes` entries for `SliceType`s will be None, and |
2598 | 2598 | # we need to give `indexed_result_shape` the actual slices. |
2599 | | - if isinstance(getattr(idx, "type", None), SliceType): |
2600 | | - index_shapes[i] = idx |
| 2599 | + elif isinstance(getattr(idx, "type", None), SliceType): |
| 2600 | + index_shapes.append(idx) |
| 2601 | + else: |
| 2602 | + index_shapes.append(ishape) |
2601 | 2603 |
|
2602 | | - res_shape = indexed_result_shape( |
2603 | | - ishapes[0], index_shapes, indices_are_shapes=True |
| 2604 | + res_shape = list( |
| 2605 | + indexed_result_shape(ishapes[0], index_shapes, indices_are_shapes=True) |
2604 | 2606 | ) |
| 2607 | + |
| 2608 | + adv_indices = [idx for idx in indices if not is_basic_idx(idx)] |
| 2609 | + bool_indices = [idx for idx in adv_indices if is_bool_index(idx)] |
| 2610 | + |
| 2611 | + # Special logic when the only advanced index group is of bool type. |
| 2612 | + # We can replace the nonzeros by a sum of the whole bool variable. |
| 2613 | + if len(bool_indices) == 1 and len(adv_indices) == 1: |
| 2614 | + [bool_index] = bool_indices |
| 2615 | + # Find the output dim associated with the bool index group |
| 2616 | + # Because there are no more advanced index groups, there is exactly |
| 2617 | + # one output dim per index variable up to the bool group. |
| 2618 | + # Note: Scalar integer indexing counts as advanced indexing. |
| 2619 | + start_dim = indices.index(bool_index) |
| 2620 | + res_shape[start_dim] = bool_index.sum() |
| 2621 | + |
2605 | 2622 | assert node.outputs[0].ndim == len(res_shape) |
2606 | | - return [list(res_shape)] |
| 2623 | + return [res_shape] |
2607 | 2624 |
|
2608 | 2625 | def perform(self, node, inputs, out_): |
2609 | 2626 | (out,) = out_ |
|
0 commit comments