11use super :: core:: {
2- af_array, AfError , Array , BinaryOp , HasAfEnum , RealNumber , ReduceByKeyInput , Scanable ,
2+ af_array, AfError , Array , BinaryOp , Fromf64 , HasAfEnum , RealNumber , ReduceByKeyInput , Scanable ,
33 HANDLE_ERROR ,
44} ;
55
@@ -518,9 +518,13 @@ where
518518}
519519
520520macro_rules! all_reduce_func_def {
521- ( $doc_str: expr, $fn_name: ident, $ffi_name: ident) => {
521+ ( $doc_str: expr, $fn_name: ident, $ffi_name: ident, $out_type : ty ) => {
522522 #[ doc=$doc_str]
523- pub fn $fn_name<T : HasAfEnum >( input: & Array <T >) -> ( f64 , f64 ) {
523+ pub fn $fn_name<T >( input: & Array <T >) -> ( $out_type, $out_type)
524+ where
525+ T : HasAfEnum ,
526+ $out_type: HasAfEnum + Fromf64
527+ {
524528 let mut real: f64 = 0.0 ;
525529 let mut imag: f64 = 0.0 ;
526530 unsafe {
@@ -529,7 +533,7 @@ macro_rules! all_reduce_func_def {
529533 ) ;
530534 HANDLE_ERROR ( AfError :: from( err_val) ) ;
531535 }
532- ( real, imag)
536+ ( <$out_type> :: fromf64 ( real) , <$out_type> :: fromf64 ( imag) )
533537 }
534538 } ;
535539}
@@ -559,7 +563,8 @@ all_reduce_func_def!(
559563 ```
560564 " ,
561565 sum_all,
562- af_sum_all
566+ af_sum_all,
567+ T :: AggregateOutType
563568) ;
564569
565570all_reduce_func_def ! (
@@ -588,7 +593,8 @@ all_reduce_func_def!(
588593 ```
589594 " ,
590595 product_all,
591- af_product_all
596+ af_product_all,
597+ T :: ProductOutType
592598) ;
593599
594600all_reduce_func_def ! (
@@ -616,7 +622,8 @@ all_reduce_func_def!(
616622 ```
617623 " ,
618624 min_all,
619- af_min_all
625+ af_min_all,
626+ T :: InType
620627) ;
621628
622629all_reduce_func_def ! (
@@ -644,7 +651,8 @@ all_reduce_func_def!(
644651 ```
645652 " ,
646653 max_all,
647- af_max_all
654+ af_max_all,
655+ T :: InType
648656) ;
649657
650658all_reduce_func_def ! (
@@ -670,7 +678,8 @@ all_reduce_func_def!(
670678 ```
671679 " ,
672680 all_true_all,
673- af_all_true_all
681+ af_all_true_all,
682+ bool
674683) ;
675684
676685all_reduce_func_def ! (
@@ -696,7 +705,8 @@ all_reduce_func_def!(
696705 ```
697706 " ,
698707 any_true_all,
699- af_any_true_all
708+ af_any_true_all,
709+ bool
700710) ;
701711
702712all_reduce_func_def ! (
@@ -722,7 +732,8 @@ all_reduce_func_def!(
722732 ```
723733 " ,
724734 count_all,
725- af_count_all
735+ af_count_all,
736+ u64
726737) ;
727738
728739/// Sum all values using user provided value for `NAN`
@@ -740,7 +751,11 @@ all_reduce_func_def!(
740751/// A tuple of summation result.
741752///
742753/// Note: For non-complex data type Arrays, second value of tuple is zero.
743- pub fn sum_nan_all < T : HasAfEnum > ( input : & Array < T > , val : f64 ) -> ( f64 , f64 ) {
754+ pub fn sum_nan_all < T > ( input : & Array < T > , val : f64 ) -> ( T :: AggregateOutType , T :: AggregateOutType )
755+ where
756+ T : HasAfEnum ,
757+ T :: AggregateOutType : HasAfEnum + Fromf64 ,
758+ {
744759 let mut real: f64 = 0.0 ;
745760 let mut imag: f64 = 0.0 ;
746761 unsafe {
@@ -752,7 +767,10 @@ pub fn sum_nan_all<T: HasAfEnum>(input: &Array<T>, val: f64) -> (f64, f64) {
752767 ) ;
753768 HANDLE_ERROR ( AfError :: from ( err_val) ) ;
754769 }
755- ( real, imag)
770+ (
771+ <T :: AggregateOutType >:: fromf64 ( real) ,
772+ <T :: AggregateOutType >:: fromf64 ( imag) ,
773+ )
756774}
757775
758776/// Product of all values using user provided value for `NAN`
@@ -770,7 +788,11 @@ pub fn sum_nan_all<T: HasAfEnum>(input: &Array<T>, val: f64) -> (f64, f64) {
770788/// A tuple of product result.
771789///
772790/// Note: For non-complex data type Arrays, second value of tuple is zero.
773- pub fn product_nan_all < T : HasAfEnum > ( input : & Array < T > , val : f64 ) -> ( f64 , f64 ) {
791+ pub fn product_nan_all < T > ( input : & Array < T > , val : f64 ) -> ( T :: ProductOutType , T :: ProductOutType )
792+ where
793+ T : HasAfEnum ,
794+ T :: ProductOutType : HasAfEnum + Fromf64 ,
795+ {
774796 let mut real: f64 = 0.0 ;
775797 let mut imag: f64 = 0.0 ;
776798 unsafe {
@@ -782,7 +804,10 @@ pub fn product_nan_all<T: HasAfEnum>(input: &Array<T>, val: f64) -> (f64, f64) {
782804 ) ;
783805 HANDLE_ERROR ( AfError :: from ( err_val) ) ;
784806 }
785- ( real, imag)
807+ (
808+ <T :: ProductOutType >:: fromf64 ( real) ,
809+ <T :: ProductOutType >:: fromf64 ( imag) ,
810+ )
786811}
787812
788813macro_rules! dim_ireduce_func_def {
@@ -833,9 +858,13 @@ dim_ireduce_func_def!("
833858 " , imax, af_imax, InType ) ;
834859
835860macro_rules! all_ireduce_func_def {
836- ( $doc_str: expr, $fn_name: ident, $ffi_name: ident) => {
861+ ( $doc_str: expr, $fn_name: ident, $ffi_name: ident, $out_type : ty ) => {
837862 #[ doc=$doc_str]
838- pub fn $fn_name<T : HasAfEnum >( input: & Array <T >) -> ( f64 , f64 , u32 ) {
863+ pub fn $fn_name<T >( input: & Array <T >) -> ( $out_type, $out_type, u32 )
864+ where
865+ T : HasAfEnum ,
866+ $out_type: HasAfEnum + Fromf64
867+ {
839868 let mut real: f64 = 0.0 ;
840869 let mut imag: f64 = 0.0 ;
841870 let mut temp: u32 = 0 ;
@@ -846,7 +875,7 @@ macro_rules! all_ireduce_func_def {
846875 ) ;
847876 HANDLE_ERROR ( AfError :: from( err_val) ) ;
848877 }
849- ( real, imag, temp)
878+ ( <$out_type> :: fromf64 ( real) , <$out_type> :: fromf64 ( imag) , temp)
850879 }
851880 } ;
852881}
@@ -868,7 +897,8 @@ all_ireduce_func_def!(
868897 * index of minimum element in the third component.
869898 " ,
870899 imin_all,
871- af_imin_all
900+ af_imin_all,
901+ T :: InType
872902) ;
873903all_ireduce_func_def ! (
874904 "
@@ -887,7 +917,8 @@ all_ireduce_func_def!(
887917 - index of maximum element in the third component.
888918 " ,
889919 imax_all,
890- af_imax_all
920+ af_imax_all,
921+ T :: InType
891922) ;
892923
893924/// Locate the indices of non-zero elements.
0 commit comments