@@ -5,7 +5,7 @@ use crate::array::Array;
55use crate :: defines:: { AfError , BinaryOp } ;
66use crate :: error:: HANDLE_ERROR ;
77use crate :: util:: { AfArray , MutAfArray , MutDouble , MutUint } ;
8- use crate :: util:: { HasAfEnum , RealNumber , Scanable } ;
8+ use crate :: util:: { HasAfEnum , RealNumber , ReduceByKeyInput , Scanable } ;
99
1010#[ allow( dead_code) ]
1111extern "C" {
@@ -59,6 +59,71 @@ extern "C" {
5959 op : c_uint ,
6060 inclusive : c_int ,
6161 ) -> c_int ;
62+ fn af_all_true_by_key (
63+ keys_out : MutAfArray ,
64+ vals_out : MutAfArray ,
65+ keys : AfArray ,
66+ vals : AfArray ,
67+ dim : c_int ,
68+ ) -> c_int ;
69+ fn af_any_true_by_key (
70+ keys_out : MutAfArray ,
71+ vals_out : MutAfArray ,
72+ keys : AfArray ,
73+ vals : AfArray ,
74+ dim : c_int ,
75+ ) -> c_int ;
76+ fn af_count_by_key (
77+ keys_out : MutAfArray ,
78+ vals_out : MutAfArray ,
79+ keys : AfArray ,
80+ vals : AfArray ,
81+ dim : c_int ,
82+ ) -> c_int ;
83+ fn af_max_by_key (
84+ keys_out : MutAfArray ,
85+ vals_out : MutAfArray ,
86+ keys : AfArray ,
87+ vals : AfArray ,
88+ dim : c_int ,
89+ ) -> c_int ;
90+ fn af_min_by_key (
91+ keys_out : MutAfArray ,
92+ vals_out : MutAfArray ,
93+ keys : AfArray ,
94+ vals : AfArray ,
95+ dim : c_int ,
96+ ) -> c_int ;
97+ fn af_product_by_key (
98+ keys_out : MutAfArray ,
99+ vals_out : MutAfArray ,
100+ keys : AfArray ,
101+ vals : AfArray ,
102+ dim : c_int ,
103+ ) -> c_int ;
104+ fn af_product_by_key_nan (
105+ keys_out : MutAfArray ,
106+ vals_out : MutAfArray ,
107+ keys : AfArray ,
108+ vals : AfArray ,
109+ dim : c_int ,
110+ nan_val : c_double ,
111+ ) -> c_int ;
112+ fn af_sum_by_key (
113+ keys_out : MutAfArray ,
114+ vals_out : MutAfArray ,
115+ keys : AfArray ,
116+ vals : AfArray ,
117+ dim : c_int ,
118+ ) -> c_int ;
119+ fn af_sum_by_key_nan (
120+ keys_out : MutAfArray ,
121+ vals_out : MutAfArray ,
122+ keys : AfArray ,
123+ vals : AfArray ,
124+ dim : c_int ,
125+ nan_val : c_double ,
126+ ) -> c_int ;
62127}
63128
64129macro_rules! dim_reduce_func_def {
@@ -1137,3 +1202,193 @@ where
11371202 }
11381203 temp. into ( )
11391204}
1205+
1206+ macro_rules! dim_reduce_by_key_func_def {
1207+ ( $brief_str: expr, $ex_str: expr, $fn_name: ident, $ffi_name: ident, $out_type: ty) => {
1208+ #[ doc=$brief_str]
1209+ /// # Parameters
1210+ ///
1211+ /// - `keys` - key Array
1212+ /// - `vals` - value Array
1213+ /// - `dim` - Dimension along which the input Array is reduced
1214+ ///
1215+ /// # Return Values
1216+ ///
1217+ /// Tuple of Arrays, with output keys and values after reduction
1218+ ///
1219+ #[ doc=$ex_str]
1220+ pub fn $fn_name<KeyType , ValueType >( keys: & Array <KeyType >, vals: & Array <ValueType >,
1221+ dim: i32
1222+ ) -> ( Array <KeyType >, Array <$out_type>)
1223+ where
1224+ KeyType : ReduceByKeyInput ,
1225+ ValueType : HasAfEnum ,
1226+ $out_type: HasAfEnum ,
1227+ {
1228+ let mut out_keys: i64 = 0 ;
1229+ let mut out_vals: i64 = 0 ;
1230+ unsafe {
1231+ let err_val = $ffi_name(
1232+ & mut out_keys as MutAfArray ,
1233+ & mut out_vals as MutAfArray ,
1234+ keys. get( ) as AfArray ,
1235+ vals. get( ) as AfArray ,
1236+ dim as c_int,
1237+ ) ;
1238+ HANDLE_ERROR ( AfError :: from( err_val) ) ;
1239+ }
1240+ ( out_keys. into( ) , out_vals. into( ) )
1241+ }
1242+ } ;
1243+ }
1244+
1245+ dim_reduce_by_key_func_def ! (
1246+ "
1247+ Key based AND of elements along a given dimension
1248+
1249+ All positive non-zero values are considered true, while negative and zero
1250+ values are considered as false.
1251+ " ,
1252+ "
1253+ # Examples
1254+ ```rust
1255+ use arrayfire::{Dim4, print, randu, all_true_by_key};
1256+ let dims = Dim4::new(&[5, 3, 1, 1]);
1257+ let vals = randu::<f32>(dims);
1258+ let keys = randu::<u32>(Dim4::new(&[5, 1, 1, 1]));
1259+ print(&vals);
1260+ print(&keys);
1261+ let (out_keys, out_vals) = all_true_by_key(&keys, &vals, 0);
1262+ print(&out_keys);
1263+ print(&out_vals);
1264+ ```
1265+ " ,
1266+ all_true_by_key,
1267+ af_all_true_by_key,
1268+ ValueType :: AggregateOutType
1269+ ) ;
1270+
1271+ dim_reduce_by_key_func_def ! (
1272+ "
1273+ Key based OR of elements along a given dimension
1274+
1275+ All positive non-zero values are considered true, while negative and zero
1276+ values are considered as false.
1277+ " ,
1278+ "
1279+ # Examples
1280+ ```rust
1281+ use arrayfire::{Dim4, print, randu, any_true_by_key};
1282+ let dims = Dim4::new(&[5, 3, 1, 1]);
1283+ let vals = randu::<f32>(dims);
1284+ let keys = randu::<u32>(Dim4::new(&[5, 1, 1, 1]));
1285+ print(&vals);
1286+ print(&keys);
1287+ let (out_keys, out_vals) = any_true_by_key(&keys, &vals, 0);
1288+ print(&out_keys);
1289+ print(&out_vals);
1290+ ```
1291+ " ,
1292+ any_true_by_key,
1293+ af_any_true_by_key,
1294+ ValueType :: AggregateOutType
1295+ ) ;
1296+
1297+ dim_reduce_by_key_func_def ! (
1298+ "Find total count of elements with similar keys along a given dimension" ,
1299+ "" ,
1300+ count_by_key,
1301+ af_count_by_key,
1302+ ValueType :: AggregateOutType
1303+ ) ;
1304+
1305+ dim_reduce_by_key_func_def ! (
1306+ "Find maximum among values of similar keys along a given dimension" ,
1307+ "" ,
1308+ max_by_key,
1309+ af_max_by_key,
1310+ ValueType :: AggregateOutType
1311+ ) ;
1312+
1313+ dim_reduce_by_key_func_def ! (
1314+ "Find minimum among values of similar keys along a given dimension" ,
1315+ "" ,
1316+ min_by_key,
1317+ af_min_by_key,
1318+ ValueType :: AggregateOutType
1319+ ) ;
1320+
1321+ dim_reduce_by_key_func_def ! (
1322+ "Find product of all values with similar keys along a given dimension" ,
1323+ "" ,
1324+ product_by_key,
1325+ af_product_by_key,
1326+ ValueType :: ProductOutType
1327+ ) ;
1328+
1329+ dim_reduce_by_key_func_def ! (
1330+ "Find sum of all values with similar keys along a given dimension" ,
1331+ "" ,
1332+ sum_by_key,
1333+ af_sum_by_key,
1334+ ValueType :: AggregateOutType
1335+ ) ;
1336+
1337+ macro_rules! dim_reduce_by_key_nan_func_def {
1338+ ( $brief_str: expr, $ex_str: expr, $fn_name: ident, $ffi_name: ident, $out_type: ty) => {
1339+ #[ doc=$brief_str]
1340+ ///
1341+ /// This version of sum by key can replaced all NaN values in the input
1342+ /// with a user provided value before performing the reduction operation.
1343+ /// # Parameters
1344+ ///
1345+ /// - `keys` - key Array
1346+ /// - `vals` - value Array
1347+ /// - `dim` - Dimension along which the input Array is reduced
1348+ ///
1349+ /// # Return Values
1350+ ///
1351+ /// Tuple of Arrays, with output keys and values after reduction
1352+ ///
1353+ #[ doc=$ex_str]
1354+ pub fn $fn_name<KeyType , ValueType >( keys: & Array <KeyType >, vals: & Array <ValueType >,
1355+ dim: i32 , replace_value: f64
1356+ ) -> ( Array <KeyType >, Array <$out_type>)
1357+ where
1358+ KeyType : ReduceByKeyInput ,
1359+ ValueType : HasAfEnum ,
1360+ $out_type: HasAfEnum ,
1361+ {
1362+ let mut out_keys: i64 = 0 ;
1363+ let mut out_vals: i64 = 0 ;
1364+ unsafe {
1365+ let err_val = $ffi_name(
1366+ & mut out_keys as MutAfArray ,
1367+ & mut out_vals as MutAfArray ,
1368+ keys. get( ) as AfArray ,
1369+ vals. get( ) as AfArray ,
1370+ dim as c_int,
1371+ replace_value as c_double,
1372+ ) ;
1373+ HANDLE_ERROR ( AfError :: from( err_val) ) ;
1374+ }
1375+ ( out_keys. into( ) , out_vals. into( ) )
1376+ }
1377+ } ;
1378+ }
1379+
1380+ dim_reduce_by_key_nan_func_def ! (
1381+ "Compute sum of all values with similar keys along a given dimension" ,
1382+ "" ,
1383+ sum_by_key_nan,
1384+ af_sum_by_key_nan,
1385+ ValueType :: AggregateOutType
1386+ ) ;
1387+
1388+ dim_reduce_by_key_nan_func_def ! (
1389+ "Compute product of all values with similar keys along a given dimension" ,
1390+ "" ,
1391+ product_by_key_nan,
1392+ af_product_by_key_nan,
1393+ ValueType :: ProductOutType
1394+ ) ;
0 commit comments