@@ -4,7 +4,16 @@ use super::error::HANDLE_ERROR;
44use super :: util:: { af_array, dim_t, void_ptr, HasAfEnum } ;
55
66use libc:: { c_char, c_int, c_longlong, c_uint, c_void} ;
7+ #[ cfg( feature = "afserde" ) ]
8+ use serde:: de:: { Deserializer , Error , Unexpected } ;
9+ #[ cfg( feature = "afserde" ) ]
10+ use serde:: ser:: Serializer ;
11+ #[ cfg( feature = "afserde" ) ]
12+ use serde:: { Deserialize , Serialize } ;
13+ use std:: clone:: Clone ;
14+ use std:: default:: Default ;
715use std:: ffi:: CString ;
16+ use std:: fmt:: Debug ;
817use std:: marker:: PhantomData ;
918
1019// Some unused functions from array.h in C-API of ArrayFire
@@ -851,12 +860,73 @@ pub fn is_eval_manual() -> bool {
851860 }
852861}
853862
863+ #[ cfg( feature = "afserde" ) ]
864+ #[ derive( Debug , Serialize , Deserialize ) ]
865+ struct ArrayOnHost < T : HasAfEnum + Debug > {
866+ dtype : DType ,
867+ shape : Dim4 ,
868+ data : Vec < T > ,
869+ }
870+
871+ /// Serialize Implementation of Array
872+ #[ cfg( feature = "afserde" ) ]
873+ impl < T > Serialize for Array < T >
874+ where
875+ T : Default + Clone + Serialize + HasAfEnum + Debug ,
876+ {
877+ fn serialize < S > ( & self , serializer : S ) -> Result < S :: Ok , S :: Error >
878+ where
879+ S : Serializer ,
880+ {
881+ let mut vec = vec ! [ T :: default ( ) ; self . elements( ) ] ;
882+ self . host ( & mut vec) ;
883+ let arr_on_host = ArrayOnHost {
884+ dtype : self . get_type ( ) ,
885+ shape : self . dims ( ) . clone ( ) ,
886+ data : vec,
887+ } ;
888+ arr_on_host. serialize ( serializer)
889+ }
890+ }
891+
892+ /// Deserialize Implementation of Array
893+ #[ cfg( feature = "afserde" ) ]
894+ impl < ' de , T > Deserialize < ' de > for Array < T >
895+ where
896+ T : Deserialize < ' de > + HasAfEnum + Debug ,
897+ {
898+ fn deserialize < D > ( deserializer : D ) -> Result < Self , D :: Error >
899+ where
900+ D : Deserializer < ' de > ,
901+ {
902+ match ArrayOnHost :: < T > :: deserialize ( deserializer) {
903+ Ok ( arr_on_host) => {
904+ let read_dtype = arr_on_host. dtype ;
905+ let expected_dtype = T :: get_af_dtype ( ) ;
906+ if expected_dtype != read_dtype {
907+ let error_msg = format ! (
908+ "data type is {:?}, deserialized type is {:?}" ,
909+ expected_dtype, read_dtype
910+ ) ;
911+ return Err ( Error :: invalid_value ( Unexpected :: Enum , & error_msg. as_str ( ) ) ) ;
912+ }
913+ Ok ( Array :: < T > :: new (
914+ & arr_on_host. data ,
915+ arr_on_host. shape . clone ( ) ,
916+ ) )
917+ }
918+ Err ( err) => Err ( err) ,
919+ }
920+ }
921+ }
922+
854923#[ cfg( test) ]
855924mod tests {
925+ use super :: super :: super :: algorithm:: sum_all;
856926 use super :: super :: array:: print;
857927 use super :: super :: data:: constant;
858928 use super :: super :: device:: { info, set_device, sync} ;
859- use crate :: dim4;
929+ use crate :: { dim4, randu } ;
860930 use std:: sync:: { mpsc, Arc , RwLock } ;
861931 use std:: thread;
862932
@@ -1082,4 +1152,36 @@ mod tests {
10821152 // 8.0000 8.0000 8.0000
10831153 // ANCHOR_END: accum_using_channel
10841154 }
1155+
1156+ #[ test]
1157+ #[ cfg( feature = "afserde" ) ]
1158+ fn array_serde_json ( ) {
1159+ use super :: Array ;
1160+
1161+ let input = randu ! ( u8 ; 2 , 2 ) ;
1162+ let serd = match serde_json:: to_string ( & input) {
1163+ Ok ( serialized_str) => serialized_str,
1164+ Err ( e) => e. to_string ( ) ,
1165+ } ;
1166+
1167+ let deserd: Array < u8 > = serde_json:: from_str ( & serd) . unwrap ( ) ;
1168+
1169+ assert_eq ! ( sum_all( & ( input - deserd) ) , ( 0u32 , 0u32 ) ) ;
1170+ }
1171+
1172+ #[ test]
1173+ #[ cfg( feature = "afserde" ) ]
1174+ fn array_serde_bincode ( ) {
1175+ use super :: Array ;
1176+
1177+ let input = randu ! ( u8 ; 2 , 2 ) ;
1178+ let encoded = match bincode:: serialize ( & input) {
1179+ Ok ( encoded) => encoded,
1180+ Err ( _) => vec ! [ ] ,
1181+ } ;
1182+
1183+ let decoded: Array < u8 > = bincode:: deserialize ( & encoded) . unwrap ( ) ;
1184+
1185+ assert_eq ! ( sum_all( & ( input - decoded) ) , ( 0u32 , 0u32 ) ) ;
1186+ }
10851187}
0 commit comments