Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
*.jl.*.cov
*.jl.cov
*.jl.mem
/Manifest.toml
Manifest.toml
/docs/build/
10 changes: 10 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,15 @@ uuid = "652893fb-f6a0-4a00-a44a-7fb8fac69e01"
authors = ["Adrian Hill <[email protected]>"]
version = "0.1.0"

[deps]
Clustering = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5"
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
ImageBase = "c817782e-172a-44cc-b673-b171935fbb9e"
LazyModules = "8cdb02fc-e678-4876-92c5-9defec4f444e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[compat]
# Every registered (non-stdlib) package listed in [deps] carries a compat bound.
Clustering = "0.14.3"
Colors = "0.12"
ImageBase = "0.1"
LazyModules = "0.3"
julia = "1.6"
18 changes: 17 additions & 1 deletion src/ColorQuantization.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,21 @@
# Package entry point: loads dependencies, declares the abstract quantizer type,
# includes the implementation files and exports the public names.
module ColorQuantization

using Colors
using ImageBase: channelview, colorview, floattype, restrict
using Random: AbstractRNG, GLOBAL_RNG
using LazyModules: @lazy
# Clustering.jl is imported lazily through LazyModules and is identified by its UUID.
#! format: off
@lazy import Clustering = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5"
#! format: on

# Supertype of all quantization algorithms. `quantize` takes any subtype as `alg`
# and calls it on the array of colors.
abstract type AbstractColorQuantizer end

include("api.jl")
include("utils.jl")
include("uniform.jl")
include("clustering.jl") # lazily loaded

export AbstractColorQuantizer, quantize
export UniformQuantization, KMeansQuantization

end
16 changes: 16 additions & 0 deletions src/api.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
"""
quantize([T,] cs, alg)

Apply color quantization algorithm `alg` to an iterable collection of Colorants,
e.g. an image or any `AbstractArray`.
The return type `T` can be specified and defaults to the element type of `cs`.
"""
function quantize(cs::AbstractArray{T}, alg::AbstractColorQuantizer) where {T<:Colorant}
return quantize(T, cs, alg)
end

function quantize(
::Type{T}, cs::AbstractArray{<:Colorant}, alg::AbstractColorQuantizer
) where {T}
return convert.(T, alg(cs)[:])
end
69 changes: 69 additions & 0 deletions src/clustering.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# The following code is lazily loaded from Clustering.jl using LazyModules.jl
# Code adapted from @cormullion's ColorSchemeTools (https://github.com/JuliaGraphics/ColorSchemeTools.jl)

# Types accepted for the keyword argument `init`; the definition is taken from Clustering.jl.
# It serves as upper bound for the type of the `init` field of `KMeansQuantization`.
const ClusteringInitType = Union{
Symbol,Clustering.SeedingAlgorithm,AbstractVector{<:Integer}
}

# Colorspace in which K-means is computed when the caller does not specify one.
const KMEANS_DEFAULT_COLORSPACE = RGB{Float32}

"""
KMeansQuantization([T=RGB,] ncolors)

Quantize colors by applying the K-means method, where `ncolors` corresponds to
the amount of clusters and output colors.

The colorspace `T` in which K-means are computed defaults to `RGB`.

## Optional arguments
The following keyword arguments from Clustering.jl can be specified:
- `init`: specifies how cluster seeds are initialized
- `maxiter`: maximum number of iterations
- `tol`: minimal allowed change of the objective during convergence.
The algorithm is considered to be converged when the change of objective value between
consecutive iterations drops below `tol`.

The default values are carried over from are imported from Clustering.jl.
For more details, refer to the [documentation](https://juliastats.org/Clustering.jl/stable/)
of Clustering.jl.
"""
struct KMeansQuantization{T<:Colorant,I<:ClusteringInitType,R<:AbstractRNG} <:
AbstractColorQuantizer
ncolors::Int
maxiter::Int
tol::Float64
init::I
rng::R

function KMeansQuantization(
T::Type{<:Colorant},
ncolors::Integer;
init=Clustering._kmeans_default_init,
maxiter=Clustering._kmeans_default_maxiter,
tol=Clustering._kmeans_default_tol,
rng=GLOBAL_RNG,
)
ncolors ≥ 2 ||
throw(ArgumentError("K-means clustering requires ncolors ≥ 2, got $(ncolors)."))
return new{T,typeof(init),typeof(rng)}(ncolors, maxiter, tol, init, rng)
end
end
# Convenience constructor: cluster in the default colorspace `KMEANS_DEFAULT_COLORSPACE`.
KMeansQuantization(ncolors::Integer; kwargs...) =
    KMeansQuantization(KMEANS_DEFAULT_COLORSPACE, ncolors; kwargs...)

# Apply K-means quantization to the colors `cs` and return the cluster centers as colors
# in the colorspace `T` of the quantizer.
function (alg::KMeansQuantization{T})(cs::AbstractArray{<:Colorant}) where {T}
    # Clustering a downsampled image already yields a good enough colormap estimate
    # and significantly reduces the algorithmic complexity.
    max_elements = alg.ncolors * 100
    downsampled = _restrict_to(cs, max_elements)
    return _kmeans(alg, convert.(T, downsampled))
end

# Run `Clustering.kmeans` on the channels of `cs` with the settings stored in `alg` and
# return the resulting centers as colors of the same type as the elements of `cs`.
function _kmeans(alg::KMeansQuantization, cs::AbstractArray{<:Colorant{T,N}}) where {T,N}
    # Arrange the color channels as an N × (number of colors) matrix.
    channels = reshape(channelview(cs), N, :)
    result = Clustering.kmeans(
        channels, alg.ncolors; init=alg.init, maxiter=alg.maxiter, tol=alg.tol, rng=alg.rng
    )
    return colorview(eltype(cs), result.centers)
end
20 changes: 20 additions & 0 deletions src/uniform.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""
UniformQuantization(n::Int)

Quantize colors in RGB color space by dividing each dimension of the ``[0, 1]³``
RGB color cube into `n` equidistant steps. This results in a grid with ``(n+1)³`` points.
Each color in `cs` is then quantized to the closest point on the grid. Only unique colors
are returned. The amount of output colors is therefore bounded by ``(n+1)³``.
"""
struct UniformQuantization <: AbstractColorQuantizer
n::Int

function UniformQuantization(n)
return n < 1 ? throw(ArgumentError("n has to be ≥ 1, got $(n).")) : new(n)
end
end

# Generic entry point: convert any colors to `RGB{Float32}` and quantize those.
function (alg::UniformQuantization)(cs::AbstractArray{<:Colorant})
    return alg(convert.(RGB{Float32}, cs))
end

# Floating-point RGB colors: scale the channels by `n`, round them to integers, keep
# unique columns only and scale back by `1 / n`.
function (alg::UniformQuantization)(cs::AbstractArray{T}) where {T<:RGB{<:AbstractFloat}}
    channels = channelview(cs[:])
    rounded = round.(channels * alg.n)
    grid_colors = unique(rounded; dims=2) / alg.n
    return colorview(T, grid_colors)
end
Copy link
Member

@johnnychen94 johnnychen94 Sep 30, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is okay. But I get something faster by making it a plain lookup table (The key idea is to pass T,N to the compiler so that we can do the computation in the compilation stage using generated function):

using ImageCore, BenchmarkTools

struct UniformQuantization{T, N} end

@inline function quantize(::UniformQuantization{T,N}, x::T) where {T,N}
    _naive_quantize(UniformQuantization{T,N}(), x)
end

@inline function quantize(::UniformQuantization{T,N}, x::T, n) where {T<:Union{N0f8,},N}
    @inbounds _build_lookup_table(UniformQuantization{T,N}())[reinterpret(FixedPointNumbers.rawtype(T), x) + 1]
end

# for every combination of {T,N}, this is only done once during the compilation
@generated function _build_lookup_table(::UniformQuantization{T,N}) where {T<:FixedPoint,N}
    # This errors when T requires more than 32 bits
    tmax = typemax(FixedPointNumbers.rawtype(T))
    table = Vector{T}(undef, tmax + 1)
    for raw_x in zero(tmax):tmax
        x = reinterpret(T, raw_x)
        table[raw_x + 1] = _naive_quantize(UniformQuantization{T,N}(), x)
    end
    return table
end

function _naive_quantize(::UniformQuantization{T,N}, x::T) where {T,N}
    x ≤ 0 && return 1 / (2 * N)
    x ≥ 1 && return (2 * N - 1) / (2 * N)
    return (round(x * N - T(0.5)) + T(0.5)) / N
end

and I get

X = rand(N0f8, 512, 512)
Q = UniformQuantization{N0f8,32}()
@btime quantize.(Ref(Q), $X); #   95.689 μs (7 allocations: 256.24 KiB)

In the meantime, the current version is:

X = rand(N0f8, 512, 512)
Q = UniformQuantization(32)
@btime Q(X); #   5.336 ms (17 allocations: 6.01 MiB)

This trick is only applicable for fixed point numbers. Do you think it's worth the change?
For floating-point numbers, a rounding step is needed before the lookup table can be used.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I now have both in the code for maximum flexibility w.r.t. input data types. However the benchmarks are a bit more complicated:

alg = UniformQuantization(32)
for T in (N0f8, Float32, RGB{N0f8}, RGB{Float32})
    println("Timing $T:")
    img = rand(T, 512, 512)
    @btime alg($img)
    @btime alg.($img) # skips `unique`
end
Timing N0f8:
  25.436 ms (15 allocations: 1.50 MiB)
  91.459 μs (4 allocations: 256.17 KiB)
Timing Float32:
  4.659 ms (17 allocations: 6.01 MiB)
  75.006 μs (3 allocations: 1.00 MiB)
Timing RGB{N0f8}:
  33.910 ms (38 allocations: 1.38 MiB)
  270.763 μs (3 allocations: 768.06 KiB)
Timing RGB{Float32}:
  7.663 ms (39 allocations: 5.28 MiB)
  344.923 μs (3 allocations: 3.00 MiB)

And building the look-up table for N0f32 freezes my Julia session.

Copy link
Member

@johnnychen94 johnnychen94 Oct 3, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We certainly cannot make this a generic implementation -- it is too risky to delegate this much computation to the compiler. We only need to cover the types that really matter in practice.
For generic types, a runtime lookup table (using Dict{Type,Vector}) can be used instead (but it is unclear to me whether it is faster than the naive version).

9 changes: 9 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Reduce the size of `img` by repeatedly applying `restrict` until at most `n` elements
# are left. `img` itself is returned if it already has at most `n` elements.
#
# If `restrict` can no longer reduce the number of elements, the smallest array reached
# so far is returned, even if it still holds more than `n` elements.
function _restrict_to(img, n)
    out = img
    while length(out) > n
        smaller = restrict(out)
        # Without progress the loop would never terminate, so stop here.
        length(smaller) < length(out) || break
        out = smaller
    end
    return out
end
34 changes: 0 additions & 34 deletions test/Manifest.toml

This file was deleted.

11 changes: 11 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,2 +1,13 @@
[deps]
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReferenceTests = "324d217c-45ce-50fc-942e-d289b448e8cf"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestImages = "5e47fb64-e119-507b-a336-dd2b206d9990"

[compat]
Colors = "0.12"
ReferenceTests = "0.10"
StableRNGs = "1"
TestImages = "1"
1 change: 1 addition & 0 deletions test/references/KMeansQuantization8.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
████████
1 change: 1 addition & 0 deletions test/references/KMeansQuantization8_HSV.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
████████
1 change: 1 addition & 0 deletions test/references/UniformQuantization4.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
███████████████████████████
1 change: 1 addition & 0 deletions test/references/UniformQuantization4_HSV.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
███████████████████████████
33 changes: 32 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,37 @@
using ColorQuantization
using Test
using TestImages, ReferenceTests
using Colors
using Random, StableRNGs

# Seeded RNG handed to the K-means quantizer so that its output can be compared against
# the stored reference files.
rng = StableRNG(123)
Random.seed!(rng, 34568)

# Test image that is quantized in the reference tests.
img = testimage("peppers")

# Quantizers under reference test, keyed by the base name of their reference files.
algs_deterministic = Dict(
"UniformQuantization4" => UniformQuantization(4),
"KMeansQuantization8" => KMeansQuantization(8; rng=rng),
)

@testset "ColorQuantization.jl" begin
    # Compare the output of the deterministic quantizers against stored references
    @testset "Reference tests" begin
        @testset "$name" for (name, alg) in algs_deterministic
            # Default output type: element type of the input image
            palette = quantize(img, alg)
            @test eltype(palette) == eltype(img)
            @test_reference "references/$(name).txt" palette

            # Explicitly requested output type
            palette_hsv = quantize(HSV{Float16}, img, alg)
            @test eltype(palette_hsv) == HSV{Float16}
            @test_reference "references/$(name)_HSV.txt" palette_hsv
        end
    end

    # Invalid constructor arguments are rejected
    @testset "Error messages" begin
        @test_throws ArgumentError UniformQuantization(0)
        @test_throws ArgumentError KMeansQuantization(0)
    end
end