Skip to content

Commit ffea773

Browse files
committed
Updates in atomistic examples. Change of inclusion probability function. Updates in tests.
1 parent 051128e commit ffea773

File tree

6 files changed

+53
-50
lines changed

6 files changed

+53
-50
lines changed

examples/atomistic/sme-iso17.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
#using Pkg
2+
#Pkg.develop(path="../../")
3+
14
using StreamingSampling
25

36
include("utils/utils.jl")

examples/atomistic/srs-vs-sme-aspirin-rmd17.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
#using Pkg
2+
#Pkg.develop(path="../../")
3+
14
using StreamingSampling
25

36
include("utils/utils.jl")

examples/atomistic/srs-vs-sme-hfo2.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
using Pkg
2+
Pkg.develop(path="../../")
3+
14
using StreamingSampling
25

36
include("utils/utils.jl")
@@ -140,7 +143,7 @@ calc_descr!(ds_train_rnd, basis_fitting)
140143
calc_descr!(ds_test_rnd, basis_fitting)
141144

142145
# Initialize StreamMaxEnt sampler ##############################################
143-
read_element(io) = read_element_extxyz(io)
146+
read_conf(x::Configuration) = x
144147
basis = ACE(species = [:C, :O, :H],
145148
body_order = 4,
146149
polynomial_degree = 8,
@@ -153,8 +156,8 @@ function create_feature(element::Vector; basis=basis)
153156
feature = sum(compute_local_descriptors(system, basis))
154157
return feature
155158
end
156-
sme = StreamMaxEnt(train_path;
157-
read_element=read_element,
159+
sme = StreamMaxEnt(ds_train_rnd.Configurations;
160+
read_element=read_conf,
158161
create_feature=create_feature,
159162
chunksize=2000,
160163
subchunksize=200)

examples/atomistic/srs-vs-sme-iso17.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
1+
#using Pkg
2+
#Pkg.develop(path="../../")
3+
14
using StreamingSampling
25

36
include("utils/utils.jl")
47

58
# Define paths and create experiment folder
69
train_path = ["data/iso17/my_iso17_train.extxyz"]
710
test_path = ["data/iso17/my_iso17_test.extxyz"]
8-
res_path = "results-full-iso17/"
11+
res_path = "results-iso17/"
912
run(`mkdir -p $res_path`)
1013

1114
# Initialize StreamMaxEnt sampler ##############################################

src/IncluProbs.jl

Lines changed: 35 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,42 @@
1+
using JuMP
2+
using Ipopt
3+
14
# Inclusion probabilities
25
# Transform weights into first order inclusion probabilities
3-
# 1: p(w)=w*x+y
6+
# 1: p(w)=n*w/sum(w)
47
# 2: Sum constraint: sum(probs)=n => sum(p.(ws))=n
5-
# 3: Minimum probability constraint: p(w_min)=w_min
6-
#
7-
# Algebra: y = w_min*(1 - x)
8-
# x = (n - N*w_min) / (sum(ws) - N*w_min) == (n - N*w_min) / (N*(μ - w_min))
9-
function inclusion_prob(sampler::Sampler, n::Integer; highprec::Bool=false)
8+
function inclusion_prob(sampler::Sampler, n::Int)
109
@views ws = sampler.weights
1110
N = length(ws)
12-
@assert N > 0 "empty weights"
13-
14-
if !highprec
15-
wmin = minimum(ws)
16-
μ = mean(ws)
17-
den = N*(μ - wmin)
18-
num = n - N*wmin
19-
20-
# If nearly singular, switch to high precision automatically
21-
if den == 0 || abs(den) ≤ 1e-12 * max(1.0, abs(N*μ))
22-
return inclusion_prob(sampler, n; highprec=true)
23-
end
24-
25-
x = num / den
26-
y = muladd(-wmin, x, wmin) # y = wmin*(1 - x)
27-
ps = @. muladd(x, ws, y) # p = x*ws + y
28-
29-
return ps
30-
else
31-
# Compute coefficients in high precision, then cast back
32-
ps = let
33-
setprecision(256) do
34-
W = big.(ws)
35-
NB = big(N)
36-
nB = big(n)
37-
wmin = minimum(W)
38-
μ = mean(W)
39-
den = NB*(μ - wmin)
40-
num = nB - NB*wmin
41-
@assert den != 0 "Constraints are singular: mean(ws) == w_min"
42-
43-
x = num / den
44-
y = wmin*(1 - x)
45-
T = eltype(ws)
46-
T.( @. x*W + y )
47-
end
48-
end
49-
return ps
11+
12+
# Start with probabilities proportional to weights
13+
probs_proportional = n * ws / sum(ws)
14+
15+
# Check if already in [0,1] - if so, we're done!
16+
if all(0 .<= probs_proportional .<= 1)
17+
return probs_proportional
5018
end
19+
20+
# Otherwise, need to adjust with optimization
21+
model = Model(Ipopt.Optimizer)
22+
set_silent(model)
23+
24+
# Direct variables for each probability
25+
@variable(model, 0 <= p[i=1:N] <= 1)
26+
27+
# Constraint: Sum equals n
28+
@constraint(model, sum(p) == n)
29+
30+
# Objective: minimize deviation from proportional probabilities
31+
# This maintains the relative importance from weights
32+
@objective(model, Min, sum((p[i] - probs_proportional[i])^2 for i in 1:N))
33+
34+
optimize!(model)
35+
36+
if termination_status(model) != MOI.OPTIMAL && termination_status(model) != MOI.LOCALLY_SOLVED
37+
error("Could not find valid probabilities. Check that 0 < n <= N=$(N)")
38+
end
39+
40+
return value.(p)
5141
end
42+

test/runtests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,14 @@ using Statistics
1212
"data/data3.txt",
1313
"data/data4.txt"]
1414
sme = StreamMaxEnt(file_paths; chunksize=1000, subchunksize=100)
15-
n = 100
15+
n = 2351
1616
inds = StreamingSampling.sample(sme, n)
1717
ps = StreamingSampling.inclusion_prob(sme, n)
1818

1919
println("Checking sample size.")
2020
@test round(Int, sum(ps)) ≈ length(inds)
2121

22-
println("Checking sum(ps)==n.")
22+
println("Checking sum(ps)≈n.")
2323
@test round(Int, sum(ps)) ≈ n
2424

2525
println("Checking 0<=ps_i<=1.")

0 commit comments

Comments
 (0)