Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add p-Wasserstein distance #158

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
version = "0.8.2"

[deps]
Cbc = "9961bab8-2fa3-5c5a-9d89-47fab24efd76"
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

Expand Down
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ This package also provides optimized functions to compute column-wise and pairwi
* Normalized root mean squared deviation
* Bray-Curtis dissimilarity
* Bregman divergence
* Wasserstein distance

For `Euclidean distance`, `Squared Euclidean distance`, `Cityblock distance`, `Minkowski distance`, and `Hamming distance`, a weighted version is also provided.

Expand Down Expand Up @@ -173,6 +174,7 @@ Each distance corresponds to a distance type. The type name and the correspondin
| WeightedMinkowski | `wminkowski(x, y, w, p)` | `sum(abs(x - y).^p .* w) ^ (1/p)` |
| WeightedHamming | `whamming(x, y, w)` | `sum((x .!= y) .* w)` |
| Bregman | `bregman(F, ∇, x, y; inner = LinearAlgebra.dot)` | `F(x) - F(y) - inner(∇(y), x - y)` |
| p-Wasserstein | `wasserstein(a, b, p)` | See (2,5) [here](https://arxiv.org/abs/1801.07745) |

**Note:** The formulas above are using *Julia*'s functions. These formulas are mainly for conveying the math concepts in a concise way. The actual implementation may use a faster way. The arguments `x` and `y` are arrays of real numbers; `k` and `l` are arrays of distinct elements of any kind; a and b are arrays of Bools; and finally, `p` and `q` are arrays forming a discrete probability distribution and are therefore both expected to sum to one.

Expand Down Expand Up @@ -203,6 +205,11 @@ julia> pairwise(Euclidean(1e-12), x, x)
0.0
```

## Notes on Wasserstein distance

The p-Wasserstein distances can only be calculated for values of the same probability simplex (i.e. non-negative real values with sum 1)

The calculation of the p-Wasserstein distance contains the solution of a linear program in `N^2` variables. This metric is quite expensive to calculate.

## Benchmarks

Expand Down
3 changes: 3 additions & 0 deletions src/Distances.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ export
RMSDeviation,
NormRMSDeviation,
Bregman,
Wasserstein,

# convenient functions
euclidean,
Expand Down Expand Up @@ -91,6 +92,7 @@ export
bhattacharyya,
hellinger,
bregman,
wasserstein,

haversine,

Expand All @@ -107,5 +109,6 @@ include("haversine.jl")
include("mahalanobis.jl")
include("bhattacharyya.jl")
include("bregman.jl")
include("wasserstein.jl")

end # module end
73 changes: 73 additions & 0 deletions src/wasserstein.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Wasserstein distance

using JuMP: Model, AffExpr, with_optimizer, @variable, @constraint,
@objective, add_to_expression!, optimize!, termination_status,
objective_value
import JuMP
import Cbc

struct Wasserstein <: Metric
p::Float64

function Wasserstein(p::Float64)
@assert p >= 1
new(p)
end
end

Wasserstein() = Wasserstein(1.0)

function (dist::Wasserstein)(a::AbstractArray{T}, b::AbstractArray{T}) where {T}
@assert length(a) == length(b)

isempty(a) && return zero(T)

@assert isapprox(sum(a), 1.0, atol=1e-6) "sum(a) needs to be ~1 but is $(sum(a))"
@assert isapprox(sum(b), 1.0, atol=1e-6) "sum(b) needs to be ~1 but is $(sum(b))"

model = make_wasserstein_model(a, b, dist.p)
optimize!(model)
@assert termination_status(model) == JuMP.MOI.OPTIMAL

objective_value(model)^(1/dist.p)
end

"""
Create JuMP `Model` for linear program to calculate the p-Wasserstein distance
of two discrete vectors from the same probability simplex. See also formula
(2.5) in [Optimal Transport on Discrete Domains](https://arxiv.org/abs/1801.07745).
"""
function make_wasserstein_model(a::AbstractArray, b::AbstractArray, p::Float64) :: Model
model = Model(with_optimizer(Cbc.Optimizer, logLevel=0))

N = length(a)
T = @variable(model, T[1:N, 1:N] >= 0)

for i in 1:N
row_expression = AffExpr()
for j in 1:N
add_to_expression!(row_expression, 1.0, T[i, j])
end
@constraint(model, row_expression == a[i])
end

for j in 1:N
column_expression = AffExpr()
for i in 1:N
add_to_expression!(column_expression, 1.0, T[i, j])
end
@constraint(model, column_expression == b[j])
end

objective_expression = AffExpr()
for i in 1:N
for j in 1:N
add_to_expression!(objective_expression, abs(i - j)^p, T[i, j])
end
end
@objective(model, Min, objective_expression)

model
end

wasserstein(a::AbstractArray, b::AbstractArray, p::Float64=1.0) = Wasserstein(p)(a, b)
62 changes: 45 additions & 17 deletions test/test_dists.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
# Unit tests for Distances

"Mapx `x` to the probability simplex"
function positive_and_normed(x)
abs_x = abs.(x)
M = sum(abs_x)
abs_x./M
end

function test_metricity(dist, x, y, z)
@testset "Test metricity of $(typeof(dist))" begin
@test dist(x, y) == evaluate(dist, x, y)
Expand Down Expand Up @@ -117,6 +124,11 @@ end
test_metricity(RenyiDivergence(2), p, q, r)
test_metricity(RenyiDivergence(10), p, q, r)
test_metricity(JSDivergence(), p, q, r)

let x′, y′, z′
x′, y′, z′ = positive_and_normed.([x, y, z])
test_metricity(Wasserstein(), x′, y′, z′)
end
end

@testset "individual metrics" begin
Expand Down Expand Up @@ -177,6 +189,15 @@ end
@test weuclidean(x, y, w) == sqrt(wsqeuclidean(x, y, w))
@test wcityblock(x, y, w) ≈ dot(abs.(x - vec(y)), w)
@test wminkowski(x, y, w, 2) ≈ weuclidean(x, y, w)

let x′, y′
x′, y′ = positive_and_normed.([x, y])
@test wasserstein(x′, y′) ≈ 0.471861471 atol=1e-6
@test wasserstein(x′, y′) != wasserstein(x′, y′, 2.0)
@test wasserstein(x′, y′, 2.0) ≈ 0.68692173634 atol=1e-6
end
@test_throws AssertionError wasserstein(x, y)
@test_throws AssertionError Wasserstein(0.5)
end

# Test ChiSq doesn't give NaN at zero
Expand Down Expand Up @@ -267,6 +288,7 @@ end #testset
@test isa(renyi_divergence(a, b, 2.0), T)
@test braycurtis(a, b) == 0.0
@test isa(braycurtis(a, b), T)
@test isa(wasserstein(a, b), T)

w = T[]
@test isa(whamming(a, b, w), T)
Expand Down Expand Up @@ -422,6 +444,19 @@ function test_colwise(dist, x, y, T)
end
end

function positive_and_normed_colwise(x, T)
rows, cols = size(x)
X = zeros(T, rows, cols)
@assert size(x) == size(X)
for i in 1:cols
normed_col = positive_and_normed(x[:, i])
for j in 1:rows
X[j, i] = normed_col[j]
end
end
X
end

@testset "column-wise metrics on $T" for T in (Float64, F64)
m = 5
n = 8
Expand Down Expand Up @@ -478,6 +513,11 @@ end

test_colwise(SqMahalanobis(Q), X, Y, T)
test_colwise(Mahalanobis(Q), X, Y, T)
let X′, Y′
X′ = positive_and_normed_colwise(X, T)
Y′ = positive_and_normed_colwise(Y, T)
test_colwise(Wasserstein(), X′, Y′, T)
end
end

function test_pairwise(dist, x, y, T)
Expand Down Expand Up @@ -555,6 +595,11 @@ end

test_pairwise(SqMahalanobis(Q), X, Y, T)
test_pairwise(Mahalanobis(Q), X, Y, T)
let X′, Y′
X′ = positive_and_normed_colwise(X, T)
Y′ = positive_and_normed_colwise(Y, T)
test_pairwise(Wasserstein(), X′, Y′, T)
end
end

@testset "Euclidean precision" begin
Expand Down Expand Up @@ -595,20 +640,3 @@ end
end
@test bregman(F, ∇, p, q) ≈ ISdist(p, q)
end

#=
@testset "zero allocation colwise!" begin
d = Euclidean()
a = rand(2, 41)
b = rand(2, 41)
z = zeros(41)
colwise!(z, d, a, b)
# This fails when bounds checking is enforced
bounds = Base.JLOptions().check_bounds
if bounds == 0
@test (@allocated colwise!(z, d, a, b)) == 0
else
@test_broken (@allocated colwise!(z, d, a, b)) == 0
end
end
=#