|
| 1 | +# The following code is lazily loaded from Clustering.jl using LazyModules.jl |
| 2 | +# Code adapted from @cormullion's ColorSchemeTools (https://github.com/JuliaGraphics/ColorSchemeTools.jl) |
| 3 | + |
| 4 | +# The following type definition is taken from from Clustering.jl for the kwarg `init`: |
| 5 | +const ClusteringInitType = Union{ |
| 6 | + Symbol,Clustering.SeedingAlgorithm,AbstractVector{<:Integer} |
| 7 | +} |
| 8 | + |
| 9 | +const KMEANS_DEFAULT_COLORSPACE = RGB{Float32} |
| 10 | + |
| 11 | +""" |
| 12 | + KMeansQuantization([T=RGB,] ncolors) |
| 13 | +
|
| 14 | +Quantize colors by applying the K-means method, where `ncolors` corresponds to |
| 15 | +the amount of clusters and output colors. |
| 16 | +
|
| 17 | +The colorspace `T` in which K-means are computed defaults to `RGB`. |
| 18 | +
|
| 19 | +## Optional arguments |
| 20 | +The following keyword arguments from Clustering.jl can be specified: |
| 21 | +- `init`: specifies how cluster seeds are initialized |
| 22 | +- `maxiter`: maximum number of iterations |
| 23 | +- `tol`: minimal allowed change of the objective during convergence. |
| 24 | + The algorithm is considered to be converged when the change of objective value between |
| 25 | + consecutive iterations drops below `tol`. |
| 26 | +
|
| 27 | +The default values are carried over from are imported from Clustering.jl. |
| 28 | +For more details, refer to the [documentation](https://juliastats.org/Clustering.jl/stable/) |
| 29 | +of Clustering.jl. |
| 30 | +""" |
| 31 | +struct KMeansQuantization{T<:Colorant,I<:ClusteringInitType,R<:AbstractRNG} <: |
| 32 | + AbstractColorQuantizer |
| 33 | + ncolors::Int |
| 34 | + maxiter::Int |
| 35 | + tol::Float64 |
| 36 | + init::I |
| 37 | + rng::R |
| 38 | + |
| 39 | + function KMeansQuantization( |
| 40 | + T::Type{<:Colorant}, |
| 41 | + ncolors::Integer; |
| 42 | + init=Clustering._kmeans_default_init, |
| 43 | + maxiter=Clustering._kmeans_default_maxiter, |
| 44 | + tol=Clustering._kmeans_default_tol, |
| 45 | + rng=GLOBAL_RNG, |
| 46 | + ) |
| 47 | + ncolors ≥ 2 || |
| 48 | + throw(ArgumentError("K-means clustering requires ncolors ≥ 2, got $(ncolors).")) |
| 49 | + return new{T,typeof(init),typeof(rng)}(ncolors, maxiter, tol, init, rng) |
| 50 | + end |
| 51 | +end |
| 52 | +function KMeansQuantization(ncolors::Integer; kwargs...) |
| 53 | + return KMeansQuantization(KMEANS_DEFAULT_COLORSPACE, ncolors; kwargs...) |
| 54 | +end |
| 55 | + |
| 56 | +function (alg::KMeansQuantization{T})(cs::AbstractArray{<:Colorant}) where {T} |
| 57 | + # Clustering on the downsampled image already generates good enough colormap estimation. |
| 58 | + # This significantly reduces the algorithmic complexity. |
| 59 | + cs = _restrict_to(cs, alg.ncolors * 100) |
| 60 | + return _kmeans(alg, convert.(T, cs)) |
| 61 | +end |
| 62 | + |
| 63 | +function _kmeans(alg::KMeansQuantization, cs::AbstractArray{<:Colorant{T,N}}) where {T,N} |
| 64 | + data = reshape(channelview(cs), N, :) |
| 65 | + R = Clustering.kmeans( |
| 66 | + data, alg.ncolors; maxiter=alg.maxiter, tol=alg.tol, init=alg.init, rng=alg.rng |
| 67 | + ) |
| 68 | + return colorview(eltype(cs), R.centers) |
| 69 | +end |
0 commit comments