Skip to content

Commit ce063b4

Browse files
committed
Adding weights normalization. Improvements in simple example
1 parent cbd9179 commit ce063b4

File tree

5 files changed

+83
-14
lines changed

5 files changed

+83
-14
lines changed

docs/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@
22
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
33
DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
44
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
5+
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
56
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
67
StreamingSampling = "e1325ea1-13b9-452d-9115-4dbfefa12b3b"
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
[deps]
2+
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
23
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
34
StreamingSampling = "e1325ea1-13b9-452d-9115-4dbfefa12b3b"
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# helper
2+
padlims(v; frac=0.10) = begin
3+
vmin, vmax = extrema(v)
4+
span = max(vmax - vmin, eps())
5+
(vmin - frac*span, vmax + frac*span)
6+
end
7+
8+
function plot_weights(ws, inds;
9+
size=(640, 480), dpi=300,
10+
guidefs=20, tickfs=14, legendfs=16,
11+
grid=:y,
12+
margins=(left=30, right=40, bottom=30, top=20))
13+
14+
perm = sortperm(ws)
15+
s = ws[perm]
16+
x = collect(eachindex(s))
17+
m = in(Set(inds)).(perm)
18+
19+
# ensure margins are AbsoluteLength in px (avoids Int + AbsoluteLength conflicts)
20+
lp = margins.left * Plots.px
21+
rp = margins.right * Plots.px
22+
bp = margins.bottom* Plots.px
23+
tp = margins.top * Plots.px
24+
25+
plt = plot(size=size, dpi=dpi, grid=grid, framestyle=:box,
26+
left_margin=lp, right_margin=rp,
27+
bottom_margin=bp, top_margin=tp,
28+
xguidefont=font(guidefs), yguidefont=font(guidefs),
29+
xtickfont=font(tickfs), ytickfont=font(tickfs),
30+
legendfont=font(legendfs))
31+
32+
scatter!(plt, x, s;
33+
color=:gray, alpha=0.25,
34+
marker=:circle, markersize=8, markerstrokewidth=0,
35+
label="All",
36+
xlabel="Sorted element indices",
37+
ylabel="Weights",
38+
ylims=padlims(s))
39+
40+
if any(m)
41+
scatter!(plt, x[m], s[m];
42+
color=:red, marker=:utriangle,
43+
markersize=10, markerstrokewidth=0,
44+
label="Selected")
45+
else
46+
plot!(plt, NaN, NaN; label="Selected")
47+
end
48+
49+
return plt
50+
end
51+
Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using StreamingSampling
22
using StatsBase: sample, Weights
3+
using Plots
34

45
# Define file paths
56
base = haskey(ENV, "BASE_PATH") ? ENV["BASE_PATH"] : "../../"
@@ -9,14 +10,17 @@ file_paths = ["$path/data1.txt",
910
"$path/data3.txt",
1011
"$path/data4.txt"];
1112

13+
include("$base/examples/simple-example/plot_weights.jl"); #hide
14+
1215
# Define sample size
13-
n = 100;
16+
n = 30;
1417

1518
# Streaming weighted sampling
1619
ws = compute_weights(file_paths; chunksize=500, subchunksize=100)
1720
inds_w = sample(1:length(ws), Weights(ws), n; replace=false)
21+
plot_weights(ws, inds_w) #hide
1822

1923
# Streaming maximum entropy sampling
2024
s = UPmaxentropy(inclusion_prob(ws, n))
2125
inds_me = findall(s .== 1)
22-
26+
plot_weights(ws, inds_me) #hide

src/Weights.jl

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ function compute_weights(file_paths::Vector{String};
77
subchunksize=200,
88
buffersize=32,
99
max=Inf,
10-
randomized=true)
10+
randomized=true,
11+
normalize=true)
1112
ch, N = chunk_iterator(file_paths;
1213
read_element=read_element,
1314
chunksize=subchunksize,
@@ -16,11 +17,16 @@ function compute_weights(file_paths::Vector{String};
1617
if max == Inf
1718
max = N
1819
end
19-
return compute_weights(ch;
20-
create_feature=create_feature,
21-
chunksize=chunksize,
22-
subchunksize=subchunksize,
23-
max=max)
20+
ws = compute_weights(ch;
21+
create_feature=create_feature,
22+
chunksize=chunksize,
23+
subchunksize=subchunksize,
24+
max=max)
25+
if normalize
26+
wmin, wmax = minimum(ws), maximum(ws)
27+
ws = (ws .- wmin) ./ (wmax - wmin)
28+
end
29+
return ws
2430
end
2531

2632
function compute_weights(A::Vector;
@@ -30,7 +36,8 @@ function compute_weights(A::Vector;
3036
subchunksize=100,
3137
buffersize=32,
3238
max=Inf,
33-
randomized=true)
39+
randomized=true,
40+
normalize=true)
3441
ch, N = chunk_iterator(A;
3542
read_element=read_element,
3643
chunksize=subchunksize,
@@ -39,11 +46,16 @@ function compute_weights(A::Vector;
3946
if max == Inf
4047
max = N
4148
end
42-
return compute_weights(ch;
43-
create_feature=create_feature,
44-
chunksize=chunksize,
45-
subchunksize=subchunksize,
46-
max=max)
49+
ws = compute_weights(ch;
50+
create_feature=create_feature,
51+
chunksize=chunksize,
52+
subchunksize=subchunksize,
53+
max=max)
54+
if normalize
55+
wmin, wmax = minimum(ws), maximum(ws)
56+
ws = (ws .- wmin) ./ (wmax - wmin)
57+
end
58+
return ws
4759
end
4860

4961
function compute_weights(ch::Channel;

0 commit comments

Comments
 (0)