@@ -133,6 +133,13 @@ extern "C" {
133133 dim : c_int ,
134134 nan_val : c_double ,
135135 ) -> c_int ;
136+ fn af_max_ragged (
137+ val_out : * mut af_array ,
138+ idx_out : * mut af_array ,
139+ input : af_array ,
140+ ragged_len : af_array ,
141+ dim : c_int ,
142+ ) -> c_int ;
136143}
137144
138145macro_rules! dim_reduce_func_def {
@@ -1386,3 +1393,63 @@ dim_reduce_by_key_nan_func_def!(
13861393 af_product_by_key_nan,
13871394 ValueType :: ProductOutType
13881395) ;
1396+
1397+ /// Max reduction along given axis as per ragged lengths provided
1398+ ///
1399+ /// # Parameters
1400+ ///
1401+ /// - `input` contains the input values to be reduced
1402+ /// - `ragged_len` array containing number of elements to use when reducing along `dim`
1403+ /// - `dim` is the dimension along which the max operation occurs
1404+ ///
1405+ /// # Return Values
1406+ ///
1407+ /// Tuple of Arrays:
1408+ /// - First element: An Array containing the maximum ragged values in `input` along `dim`
1409+ /// according to `ragged_len`
1410+ /// - Second Element: An Array containing the locations of the maximum ragged values in
1411+ /// `input` along `dim` according to `ragged_len`
1412+ ///
1413+ /// # Examples
1414+ /// ```rust
1415+ /// use arrayfire::{Array, dim4, print, randu, max_ragged};
1416+ /// let vals: [f32; 6] = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
1417+ /// let rlens: [u32; 2] = [9, 2];
1418+ /// let varr = Array::new(&vals, dim4![3, 2]);
1419+ /// let rarr = Array::new(&rlens, dim4![1, 2]);
1420+ /// print(&varr);
1421+ /// // 1 4
1422+ /// // 2 5
1423+ /// // 3 6
1424+ /// print(&rarr); // numbers of elements to participate in reduction along given axis
1425+ /// // 9 2
1426+ /// let (out, idx) = max_ragged(&varr, &rarr, 0);
1427+ /// print(&out);
1428+ /// // 3 5
1429+ /// print(&idx);
1430+ /// // 2 1 //Since 3 is max element for given length 9 along first column
1431+ /// //Since 5 is max element for given length 2 along second column
1432+ /// ```
1433+ pub fn max_ragged < T > (
1434+ input : & Array < T > ,
1435+ ragged_len : & Array < u32 > ,
1436+ dim : i32 ,
1437+ ) -> ( Array < T :: InType > , Array < u32 > )
1438+ where
1439+ T : HasAfEnum ,
1440+ T :: InType : HasAfEnum ,
1441+ {
1442+ unsafe {
1443+ let mut out_vals: af_array = std:: ptr:: null_mut ( ) ;
1444+ let mut out_idxs: af_array = std:: ptr:: null_mut ( ) ;
1445+ let err_val = af_max_ragged (
1446+ & mut out_vals as * mut af_array ,
1447+ & mut out_idxs as * mut af_array ,
1448+ input. get ( ) ,
1449+ ragged_len. get ( ) ,
1450+ dim,
1451+ ) ;
1452+ HANDLE_ERROR ( AfError :: from ( err_val) ) ;
1453+ ( out_vals. into ( ) , out_idxs. into ( ) )
1454+ }
1455+ }
0 commit comments