@@ -82,7 +82,7 @@ for cj in c
8282    energy =  cj[2 ]
8383    s +=  energy
8484end 
85- na =  length (c[1 ][1 ]) #  a# Initialize streaming sampling ################################################ll  conf. have the same no. of atoms
85+ na =  length (c[1 ][1 ]) #  all  conf. have the same no. of atoms
8686avg_energy_per_atom =  s/ n1/ na
8787vref_dict =  Dict (:H  =>  avg_energy_per_atom,
8888                 :C  =>  avg_energy_per_atom,
@@ -101,33 +101,39 @@ for j in 1:n_experiments
101101                           chunksize= m,
102102                           buffersize= 1 ,
103103                           randomized= true )
104-     _ , test_inds =  take! (ch)
104+     cs , test_inds =  take! (ch)
105105    close (ch)
106-     test_inds  =  sort (test_inds) 
107-     test_confs  =   get_confs (test_path, test_inds) 
108-     test_ds  =   calc_descr (test_confs, basis_fitting) 
109-     open ( " test-ds-sme-iso17.jls " ,  " w " )  do  io 
110-      serialize (io, test_ds )
111-      flush (io )
106+     test_confs  =  [] 
107+     for  c  in  cs 
108+         system, energy, forces  =  c 
109+         conf  =   Configuration (system,  Energy (energy), 
110+                               Forces ([ Force (f)  for  f  in  forces]) )
111+          push! (test_confs, conf )
112112    end 
113-     # test_ds = deserialize("test-ds-sme-iso17.jls")
113+     ds_test =  DataSet (test_confs)
114+     ds_test =  calc_descr! (ds_test, basis_fitting)
115+     open (" test-ds-iso17.jls",  " w")  do  io
116+         serialize (io, ds_test)
117+         flush (io)
118+     end 
119+     # ds_test = deserialize("test-ds-iso17.jls")
114120
115121    for  n in  sample_sizes
116122        #  Sample training dataset using streaming weighted sampling ############
117123        train_inds =  StatsBase. sample (1 : length (ws), Weights (ws), n;
118-                      replace= false , ordered= true )) 
124+                      replace= false , ordered= true )
119125        # Load atomistic configurations
120-         train_confs  =  get_confs (train_path, train_inds)
126+         ds_train  =  get_confs (train_path, read_element , train_inds)
121127        # Adjust reference energies (permanent change)
122-         adjust_energies (train_confs , vref_dict)
128+         adjust_energies!  (ds_train , vref_dict)
123129        #  Compute dataset with energy and force descriptors
124-         train_ds  =  calc_descr (train_confs , basis_fitting)
130+         ds_train  =  calc_descr!  (ds_train , basis_fitting)
125131        #  Create result folder
126132        curr_sampler =  " sws" 
127133        exp_path =  " $res_path /$j -$curr_sampler -n$n /" 
128134        run (` mkdir -p $exp_path `)
129135        #  Fit and save results
130-         metrics_j =  fit (exp_path, train_ds, test_ds , basis_fitting; vref_dict= vref_dict)
136+         metrics_j =  fit (exp_path, ds_train, ds_test , basis_fitting; vref_dict= vref_dict)
131137        metrics_j =  merge (OrderedDict (" exp_number" =>  j,
132138                                      " method" =>  " $curr_sampler ",
133139                                      " batch_size_prop" =>  n/ N,
@@ -142,17 +148,17 @@ for j in 1:n_experiments
142148        train_inds =  randperm (N)[1 : n]
143149
144150        # Load atomistic configurations
145-         train_confs  =  get_confs (train_path, train_inds)
151+         ds_train  =  get_confs (train_path, read_element , train_inds)
146152        # Adjust reference energies (permanent change)
147-         adjust_energies (train_confs , vref_dict)
153+         adjust_energies!  (ds_train , vref_dict)
148154        #  Compute dataset with energy and force descriptors
149-         train_ds  =  calc_descr (train_confs , basis_fitting)
155+         ds_train  =  calc_descr!  (ds_train , basis_fitting)
150156        #  Create result folder
151157        curr_sampler =  " srs" 
152158        exp_path =  " $res_path /$j -$curr_sampler -n$n /" 
153159        run (` mkdir -p $exp_path `)
154160        #  Fit and save results
155-         metrics_j =  fit (exp_path, train_ds, test_ds , basis_fitting; vref_dict= vref_dict)
161+         metrics_j =  fit (exp_path, ds_train, ds_test , basis_fitting; vref_dict= vref_dict)
156162        metrics_j =  merge (OrderedDict (" exp_number" =>  j,
157163                                      " method" =>  " $curr_sampler ",
158164                                      " batch_size_prop" =>  n/ N,
0 commit comments