Skip to content

Commit b755848

Browse files
committed
Update of iso17 example
1 parent db3fd06 commit b755848

File tree

1 file changed

+24
-18
lines changed

1 file changed

+24
-18
lines changed

examples/atomistic/srs-vs-sme-iso17.jl

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ for cj in c
8282
energy = cj[2]
8383
s += energy
8484
end
85-
na = length(c[1][1]) # a# Initialize streaming sampling ################################################ll conf. have the same no. of atoms
85+
na = length(c[1][1]) # all conf. have the same no. of atoms
8686
avg_energy_per_atom = s/n1/na
8787
vref_dict = Dict(:H => avg_energy_per_atom,
8888
:C => avg_energy_per_atom,
@@ -101,33 +101,39 @@ for j in 1:n_experiments
101101
chunksize=m,
102102
buffersize=1,
103103
randomized=true)
104-
_, test_inds = take!(ch)
104+
cs, test_inds = take!(ch)
105105
close(ch)
106-
test_inds = sort(test_inds)
107-
test_confs = get_confs(test_path, test_inds)
108-
test_ds = calc_descr(test_confs, basis_fitting)
109-
open("test-ds-sme-iso17.jls", "w") do io
110-
serialize(io, test_ds)
111-
flush(io)
106+
test_confs = []
107+
for c in cs
108+
system, energy, forces = c
109+
conf = Configuration(system, Energy(energy),
110+
Forces([Force(f) for f in forces]))
111+
push!(test_confs, conf)
112112
end
113-
#test_ds = deserialize("test-ds-sme-iso17.jls")
113+
ds_test = DataSet(test_confs)
114+
ds_test = calc_descr!(ds_test, basis_fitting)
115+
open("test-ds-iso17.jls", "w") do io
116+
serialize(io, ds_test)
117+
flush(io)
118+
end
119+
#ds_test = deserialize("test-ds-iso17.jls")
114120

115121
for n in sample_sizes
116122
# Sample training dataset using streaming weighted sampling ############
117123
train_inds = StatsBase.sample(1:length(ws), Weights(ws), n;
118-
replace=false, ordered=true))
124+
replace=false, ordered=true)
119125
#Load atomistic configurations
120-
train_confs = get_confs(train_path, train_inds)
126+
ds_train = get_confs(train_path, read_element, train_inds)
121127
#Adjust reference energies (permanent change)
122-
adjust_energies(train_confs, vref_dict)
128+
adjust_energies!(ds_train, vref_dict)
123129
# Compute dataset with energy and force descriptors
124-
train_ds = calc_descr(train_confs, basis_fitting)
130+
ds_train = calc_descr!(ds_train, basis_fitting)
125131
# Create result folder
126132
curr_sampler = "sws"
127133
exp_path = "$res_path/$j-$curr_sampler-n$n/"
128134
run(`mkdir -p $exp_path`)
129135
# Fit and save results
130-
metrics_j = fit(exp_path, train_ds, test_ds, basis_fitting; vref_dict=vref_dict)
136+
metrics_j = fit(exp_path, ds_train, ds_test, basis_fitting; vref_dict=vref_dict)
131137
metrics_j = merge(OrderedDict("exp_number" => j,
132138
"method" => "$curr_sampler",
133139
"batch_size_prop" => n/N,
@@ -142,17 +148,17 @@ for j in 1:n_experiments
142148
train_inds = randperm(N)[1:n]
143149

144150
#Load atomistic configurations
145-
train_confs = get_confs(train_path, train_inds)
151+
ds_train = get_confs(train_path, read_element, train_inds)
146152
#Adjust reference energies (permanent change)
147-
adjust_energies(train_confs, vref_dict)
153+
adjust_energies!(ds_train, vref_dict)
148154
# Compute dataset with energy and force descriptors
149-
train_ds = calc_descr(train_confs, basis_fitting)
155+
ds_train = calc_descr!(ds_train, basis_fitting)
150156
# Create result folder
151157
curr_sampler = "srs"
152158
exp_path = "$res_path/$j-$curr_sampler-n$n/"
153159
run(`mkdir -p $exp_path`)
154160
# Fit and save results
155-
metrics_j = fit(exp_path, train_ds, test_ds, basis_fitting; vref_dict=vref_dict)
161+
metrics_j = fit(exp_path, ds_train, ds_test, basis_fitting; vref_dict=vref_dict)
156162
metrics_j = merge(OrderedDict("exp_number" => j,
157163
"method" => "$curr_sampler",
158164
"batch_size_prop" => n/N,

0 commit comments

Comments
 (0)