@@ -4,7 +4,16 @@ use super::error::HANDLE_ERROR;
44use super :: util:: { af_array, dim_t, void_ptr, HasAfEnum } ;
55
66use libc:: { c_char, c_int, c_longlong, c_uint, c_void} ;
7+ #[ cfg( feature = "afserde" ) ]
8+ use serde:: de:: { Deserializer , Error , Unexpected } ;
9+ #[ cfg( feature = "afserde" ) ]
10+ use serde:: ser:: Serializer ;
11+ #[ cfg( feature = "afserde" ) ]
12+ use serde:: { Deserialize , Serialize } ;
13+ use std:: clone:: Clone ;
14+ use std:: default:: Default ;
715use std:: ffi:: CString ;
16+ use std:: fmt:: Debug ;
817use std:: marker:: PhantomData ;
918
1019// Some unused functions from array.h in C-API of ArrayFire
@@ -851,12 +860,72 @@ pub fn is_eval_manual() -> bool {
851860 }
852861}
853862
863+ #[ derive( Debug , Serialize , Deserialize ) ]
864+ struct ArrayOnHost < T : HasAfEnum + Debug > {
865+ dtype : DType ,
866+ shape : Dim4 ,
867+ data : Vec < T > ,
868+ }
869+
870+ /// Serialize Implementation of Array
871+ #[ cfg( feature = "afserde" ) ]
872+ impl < T > Serialize for Array < T >
873+ where
874+ T : Default + Clone + Serialize + HasAfEnum + Debug ,
875+ {
876+ fn serialize < S > ( & self , serializer : S ) -> Result < S :: Ok , S :: Error >
877+ where
878+ S : Serializer ,
879+ {
880+ let mut vec = vec ! [ T :: default ( ) ; self . elements( ) ] ;
881+ self . host ( & mut vec) ;
882+ let arr_on_host = ArrayOnHost {
883+ dtype : self . get_type ( ) ,
884+ shape : self . dims ( ) . clone ( ) ,
885+ data : vec,
886+ } ;
887+ arr_on_host. serialize ( serializer)
888+ }
889+ }
890+
891+ /// Deserialize Implementation of Array
892+ #[ cfg( feature = "afserde" ) ]
893+ impl < ' de , T > Deserialize < ' de > for Array < T >
894+ where
895+ T : Deserialize < ' de > + HasAfEnum + Debug ,
896+ {
897+ fn deserialize < D > ( deserializer : D ) -> Result < Self , D :: Error >
898+ where
899+ D : Deserializer < ' de > ,
900+ {
901+ match ArrayOnHost :: < T > :: deserialize ( deserializer) {
902+ Ok ( arr_on_host) => {
903+ let read_dtype = arr_on_host. dtype ;
904+ let expected_dtype = T :: get_af_dtype ( ) ;
905+ if expected_dtype != read_dtype {
906+ let error_msg = format ! (
907+ "data type is {:?}, deserialized type is {:?}" ,
908+ expected_dtype, read_dtype
909+ ) ;
910+ return Err ( Error :: invalid_value ( Unexpected :: Enum , & error_msg. as_str ( ) ) ) ;
911+ }
912+ Ok ( Array :: < T > :: new (
913+ & arr_on_host. data ,
914+ arr_on_host. shape . clone ( ) ,
915+ ) )
916+ }
917+ Err ( err) => Err ( err) ,
918+ }
919+ }
920+ }
921+
854922#[ cfg( test) ]
855923mod tests {
924+ use super :: super :: super :: algorithm:: sum_all;
856925 use super :: super :: array:: print;
857926 use super :: super :: data:: constant;
858927 use super :: super :: device:: { info, set_device, sync} ;
859- use crate :: dim4;
928+ use crate :: { dim4, randu } ;
860929 use std:: sync:: { mpsc, Arc , RwLock } ;
861930 use std:: thread;
862931
@@ -1082,4 +1151,36 @@ mod tests {
10821151 // 8.0000 8.0000 8.0000
10831152 // ANCHOR_END: accum_using_channel
10841153 }
1154+
1155+ #[ test]
1156+ #[ cfg( feature = "afserde" ) ]
1157+ fn array_serde_json ( ) {
1158+ use super :: Array ;
1159+
1160+ let input = randu ! ( u8 ; 2 , 2 ) ;
1161+ let serd = match serde_json:: to_string ( & input) {
1162+ Ok ( serialized_str) => serialized_str,
1163+ Err ( e) => e. to_string ( ) ,
1164+ } ;
1165+
1166+ let deserd: Array < u8 > = serde_json:: from_str ( & serd) . unwrap ( ) ;
1167+
1168+ assert_eq ! ( sum_all( & ( input - deserd) ) , ( 0u32 , 0u32 ) ) ;
1169+ }
1170+
1171+ #[ test]
1172+ #[ cfg( feature = "afserde" ) ]
1173+ fn array_serde_bincode ( ) {
1174+ use super :: Array ;
1175+
1176+ let input = randu ! ( u8 ; 2 , 2 ) ;
1177+ let encoded = match bincode:: serialize ( & input) {
1178+ Ok ( encoded) => encoded,
1179+ Err ( _) => vec ! [ ] ,
1180+ } ;
1181+
1182+ let decoded: Array < u8 > = bincode:: deserialize ( & encoded) . unwrap ( ) ;
1183+
1184+ assert_eq ! ( sum_all( & ( input - decoded) ) , ( 0u32 , 0u32 ) ) ;
1185+ }
10851186}
0 commit comments