@@ -654,7 +654,7 @@ def _find_max_under_constraint(self, constrained, dependent, predicate):
654654 Args:
655655 constrained: Over these values the constraint is specified. A rank-1
656656 tensor.
657- dependent: From these values the maximum that satiesfies the
657+ dependent: From these values the maximum that satisfies the
658658 constraint is selected. Values in this tensor and in
659659 `constrained` are linked by having the same threshold at each
660660 position, hence this tensor must have the same shape.
@@ -664,11 +664,12 @@ def _find_max_under_constraint(self, constrained, dependent, predicate):
664664 Returns:
665665 maximal dependent value, if no value satisfies the constraint 0.0.
666666 """
667- feasible = ops .nonzero (predicate (constrained , self .value ))
668- feasible_exists = ops .greater (ops .size (feasible ), 0 )
669- max_dependent = ops .max (ops .take (dependent , feasible ), initial = 0 )
670-
671- return ops .where (feasible_exists , max_dependent , 0.0 )
667+ feasible = predicate (constrained , self .value )
668+ # Mask values based on whether they satisfy the constraint and take max.
669+ return ops .max (
670+ ops .multiply (dependent , ops .cast (feasible , dependent .dtype )),
671+ initial = 0 ,
672+ )
672673
673674
674675@keras_export ("keras.metrics.SensitivityAtSpecificity" )
0 commit comments