@@ -518,12 +518,17 @@ where
518518}
519519
520520macro_rules! all_reduce_func_def {
521- ( $doc_str: expr, $fn_name: ident, $ffi_name: ident, $out_type : ty ) => {
521+ ( $doc_str: expr, $fn_name: ident, $ffi_name: ident, $assoc_type : ident ) => {
522522 #[ doc=$doc_str]
523- pub fn $fn_name<T >( input: & Array <T >) -> ( $out_type, $out_type)
523+ pub fn $fn_name<T >( input: & Array <T >)
524+ -> (
525+ <<T as HasAfEnum >:: $assoc_type as HasAfEnum >:: BaseType ,
526+ <<T as HasAfEnum >:: $assoc_type as HasAfEnum >:: BaseType
527+ )
524528 where
525529 T : HasAfEnum ,
526- $out_type: HasAfEnum + Fromf64
530+ <T as HasAfEnum >:: $assoc_type: HasAfEnum ,
531+ <<T as HasAfEnum >:: $assoc_type as HasAfEnum >:: BaseType : HasAfEnum + Fromf64 ,
527532 {
528533 let mut real: f64 = 0.0 ;
529534 let mut imag: f64 = 0.0 ;
@@ -533,7 +538,10 @@ macro_rules! all_reduce_func_def {
533538 ) ;
534539 HANDLE_ERROR ( AfError :: from( err_val) ) ;
535540 }
536- ( <$out_type>:: fromf64( real) , <$out_type>:: fromf64( imag) )
541+ (
542+ <<T as HasAfEnum >:: $assoc_type as HasAfEnum >:: BaseType :: fromf64( real) ,
543+ <<T as HasAfEnum >:: $assoc_type as HasAfEnum >:: BaseType :: fromf64( imag) ,
544+ )
537545 }
538546 } ;
539547}
@@ -564,7 +572,7 @@ all_reduce_func_def!(
564572 " ,
565573 sum_all,
566574 af_sum_all,
567- T :: AggregateOutType
575+ AggregateOutType
568576) ;
569577
570578all_reduce_func_def ! (
@@ -594,7 +602,7 @@ all_reduce_func_def!(
594602 " ,
595603 product_all,
596604 af_product_all,
597- T :: ProductOutType
605+ ProductOutType
598606) ;
599607
600608all_reduce_func_def ! (
@@ -623,7 +631,7 @@ all_reduce_func_def!(
623631 " ,
624632 min_all,
625633 af_min_all,
626- T :: InType
634+ InType
627635) ;
628636
629637all_reduce_func_def ! (
@@ -652,10 +660,31 @@ all_reduce_func_def!(
652660 " ,
653661 max_all,
654662 af_max_all,
655- T :: InType
663+ InType
656664) ;
657665
658- all_reduce_func_def ! (
666+ macro_rules! all_reduce_func_def2 {
667+ ( $doc_str: expr, $fn_name: ident, $ffi_name: ident, $out_type: ty) => {
668+ #[ doc=$doc_str]
669+ pub fn $fn_name<T >( input: & Array <T >) -> ( $out_type, $out_type)
670+ where
671+ T : HasAfEnum ,
672+ $out_type: HasAfEnum + Fromf64
673+ {
674+ let mut real: f64 = 0.0 ;
675+ let mut imag: f64 = 0.0 ;
676+ unsafe {
677+ let err_val = $ffi_name(
678+ & mut real as * mut c_double, & mut imag as * mut c_double, input. get( ) ,
679+ ) ;
680+ HANDLE_ERROR ( AfError :: from( err_val) ) ;
681+ }
682+ ( <$out_type>:: fromf64( real) , <$out_type>:: fromf64( imag) )
683+ }
684+ } ;
685+ }
686+
687+ all_reduce_func_def2 ! (
659688 "
660689 Find if all values of Array are non-zero
661690
@@ -682,7 +711,7 @@ all_reduce_func_def!(
682711 bool
683712) ;
684713
685- all_reduce_func_def ! (
714+ all_reduce_func_def2 ! (
686715 "
687716 Find if any value of Array is non-zero
688717
@@ -709,7 +738,7 @@ all_reduce_func_def!(
709738 bool
710739) ;
711740
712- all_reduce_func_def ! (
741+ all_reduce_func_def2 ! (
713742 "
714743 Count number of non-zero values in the Array
715744
@@ -751,10 +780,17 @@ all_reduce_func_def!(
751780/// A tuple of summation result.
752781///
753782/// Note: For non-complex data type Arrays, second value of tuple is zero.
754- pub fn sum_nan_all < T > ( input : & Array < T > , val : f64 ) -> ( T :: AggregateOutType , T :: AggregateOutType )
783+ pub fn sum_nan_all < T > (
784+ input : & Array < T > ,
785+ val : f64 ,
786+ ) -> (
787+ <<T as HasAfEnum >:: AggregateOutType as HasAfEnum >:: BaseType ,
788+ <<T as HasAfEnum >:: AggregateOutType as HasAfEnum >:: BaseType ,
789+ )
755790where
756791 T : HasAfEnum ,
757- T :: AggregateOutType : HasAfEnum + Fromf64 ,
792+ <T as HasAfEnum >:: AggregateOutType : HasAfEnum ,
793+ <<T as HasAfEnum >:: AggregateOutType as HasAfEnum >:: BaseType : HasAfEnum + Fromf64 ,
758794{
759795 let mut real: f64 = 0.0 ;
760796 let mut imag: f64 = 0.0 ;
@@ -768,8 +804,8 @@ where
768804 HANDLE_ERROR ( AfError :: from ( err_val) ) ;
769805 }
770806 (
771- <T :: AggregateOutType > :: fromf64 ( real) ,
772- <T :: AggregateOutType > :: fromf64 ( imag) ,
807+ << T as HasAfEnum > :: AggregateOutType as HasAfEnum > :: BaseType :: fromf64 ( real) ,
808+ << T as HasAfEnum > :: AggregateOutType as HasAfEnum > :: BaseType :: fromf64 ( imag) ,
773809 )
774810}
775811
@@ -788,10 +824,17 @@ where
788824/// A tuple of product result.
789825///
790826/// Note: For non-complex data type Arrays, second value of tuple is zero.
791- pub fn product_nan_all < T > ( input : & Array < T > , val : f64 ) -> ( T :: ProductOutType , T :: ProductOutType )
827+ pub fn product_nan_all < T > (
828+ input : & Array < T > ,
829+ val : f64 ,
830+ ) -> (
831+ <<T as HasAfEnum >:: ProductOutType as HasAfEnum >:: BaseType ,
832+ <<T as HasAfEnum >:: ProductOutType as HasAfEnum >:: BaseType ,
833+ )
792834where
793835 T : HasAfEnum ,
794- T :: ProductOutType : HasAfEnum + Fromf64 ,
836+ <T as HasAfEnum >:: ProductOutType : HasAfEnum ,
837+ <<T as HasAfEnum >:: ProductOutType as HasAfEnum >:: BaseType : HasAfEnum + Fromf64 ,
795838{
796839 let mut real: f64 = 0.0 ;
797840 let mut imag: f64 = 0.0 ;
@@ -805,8 +848,8 @@ where
805848 HANDLE_ERROR ( AfError :: from ( err_val) ) ;
806849 }
807850 (
808- <T :: ProductOutType > :: fromf64 ( real) ,
809- <T :: ProductOutType > :: fromf64 ( imag) ,
851+ << T as HasAfEnum > :: ProductOutType as HasAfEnum > :: BaseType :: fromf64 ( real) ,
852+ << T as HasAfEnum > :: ProductOutType as HasAfEnum > :: BaseType :: fromf64 ( imag) ,
810853 )
811854}
812855
@@ -858,12 +901,18 @@ dim_ireduce_func_def!("
858901 " , imax, af_imax, InType ) ;
859902
860903macro_rules! all_ireduce_func_def {
861- ( $doc_str: expr, $fn_name: ident, $ffi_name: ident, $out_type : ty ) => {
904+ ( $doc_str: expr, $fn_name: ident, $ffi_name: ident, $assoc_type : ident ) => {
862905 #[ doc=$doc_str]
863- pub fn $fn_name<T >( input: & Array <T >) -> ( $out_type, $out_type, u32 )
906+ pub fn $fn_name<T >( input: & Array <T >)
907+ -> (
908+ <<T as HasAfEnum >:: $assoc_type as HasAfEnum >:: BaseType ,
909+ <<T as HasAfEnum >:: $assoc_type as HasAfEnum >:: BaseType ,
910+ u32
911+ )
864912 where
865913 T : HasAfEnum ,
866- $out_type: HasAfEnum + Fromf64
914+ <T as HasAfEnum >:: $assoc_type: HasAfEnum ,
915+ <<T as HasAfEnum >:: $assoc_type as HasAfEnum >:: BaseType : HasAfEnum + Fromf64 ,
867916 {
868917 let mut real: f64 = 0.0 ;
869918 let mut imag: f64 = 0.0 ;
@@ -875,7 +924,11 @@ macro_rules! all_ireduce_func_def {
875924 ) ;
876925 HANDLE_ERROR ( AfError :: from( err_val) ) ;
877926 }
878- ( <$out_type>:: fromf64( real) , <$out_type>:: fromf64( imag) , temp)
927+ (
928+ <<T as HasAfEnum >:: $assoc_type as HasAfEnum >:: BaseType :: fromf64( real) ,
929+ <<T as HasAfEnum >:: $assoc_type as HasAfEnum >:: BaseType :: fromf64( imag) ,
930+ temp,
931+ )
879932 }
880933 } ;
881934}
@@ -898,7 +951,7 @@ all_ireduce_func_def!(
898951 " ,
899952 imin_all,
900953 af_imin_all,
901- T :: InType
954+ InType
902955) ;
903956all_ireduce_func_def ! (
904957 "
@@ -918,7 +971,7 @@ all_ireduce_func_def!(
918971 " ,
919972 imax_all,
920973 af_imax_all,
921- T :: InType
974+ InType
922975) ;
923976
924977/// Locate the indices of non-zero elements.
@@ -1386,3 +1439,40 @@ dim_reduce_by_key_nan_func_def!(
13861439 af_product_by_key_nan,
13871440 ValueType :: ProductOutType
13881441) ;
1442+
1443+ #[ cfg( test) ]
1444+ mod tests {
1445+ use super :: super :: core:: c32;
1446+ use super :: { product_nan_all, sum_all, sum_nan_all, imin_all, imax_all} ;
1447+ use crate :: randu;
1448+
1449+ #[ test]
1450+ fn all_reduce_api ( ) {
1451+ let a = randu ! ( c32; 10 , 10 ) ;
1452+ println ! ( "Reduction of complex f32 matrix: {:?}" , sum_all( & a) ) ;
1453+
1454+ let b = randu ! ( bool ; 10 , 10 ) ;
1455+ println ! ( "reduction of bool matrix: {:?}" , sum_all( & b) ) ;
1456+
1457+ println ! (
1458+ "reduction of complex f32 matrix after replacing nan with {}: {:?}" ,
1459+ 1.0 ,
1460+ product_nan_all( & a, 1.0 )
1461+ ) ;
1462+
1463+ println ! (
1464+ "reduction of bool matrix after replacing nan with {}: {:?}" ,
1465+ 0.0 ,
1466+ sum_nan_all( & b, 0.0 )
1467+ ) ;
1468+ }
1469+
1470+ #[ test]
1471+ fn all_ireduce_api ( ) {
1472+ let a = randu ! ( c32; 10 ) ;
1473+ println ! ( "Reduction of complex f32 matrix: {:?}" , imin_all( & a) ) ;
1474+
1475+ let b = randu ! ( u32 ; 10 ) ;
1476+ println ! ( "reduction of bool matrix: {:?}" , imax_all( & b) ) ;
1477+ }
1478+ }
0 commit comments