Skip to content

Commit

Permalink
Add EarthMoversDistance in Julia
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Feb 6, 2020
1 parent 59edfb2 commit 4b17ca7
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 9 deletions.
2 changes: 0 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ version = "0.1.0"

[deps]
ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63"
Conda = "8f4d0f93-b110-5947-807f-2305c1781a2d"
ConfParser = "88353bc9-fd38-507d-a820-d3b43837d6b9"
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Expand All @@ -18,7 +17,6 @@ NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
PyPlot = "d330b81b-6aea-500a-939a-2ce795aea3ee"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ScikitLearn = "3646fa90-6ef7-5e7e-9f22-8aca16db6324"
Expand Down
12 changes: 5 additions & 7 deletions src/Histograms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,10 @@ Functions in this module:
"""
module Histograms

using PyCall
using NPZ
using ..Utilities

using Conda
Conda.add("scipy")
scsta = pyimport_conda("scipy.stats","")
include("wasserstein.jl")

"""
A simple struct to store samples for empirical PDFs (histograms, distances etc.)
Expand Down Expand Up @@ -196,13 +193,14 @@ Returns:
"""
function W1(u_samples::AbstractVector, v_samples::AbstractVector;
            normalize = false)
    # The distance is needed in both branches, so compute it once up front
    # using the native Julia implementation from wasserstein.jl.
    d = wasserstein_distance(u_samples, v_samples)
    if !normalize
        return d
    else
        # Normalize by the length of the combined support of both sample sets.
        u_m, u_M = extrema(u_samples)
        v_m, v_M = extrema(v_samples)
        L = max(u_M, v_M) - min(u_m, v_m)
        return d / L
    end
end

Expand Down
77 changes: 77 additions & 0 deletions src/wasserstein.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
#####
##### Wasserstein distance
#####

"""
    pysearchsorted(a, b; side="left")

Return, for each element of `b`, the 0-based insertion position into the
sorted collection `a`.

With `side == "left"` the result is `searchsortedfirst` minus one; for any
other value of `side` the result is `searchsortedlast`, i.e. the number of
elements of `a` that are less than or equal to the queried value.
"""
function pysearchsorted(a, b; side="left")
    # Wrap `a` so that broadcasting only iterates over `b`.
    haystack = Ref(a)
    side == "left" && return searchsortedfirst.(haystack, b) .- 1
    return searchsortedlast.(haystack, b)
end

"""
    _cdf_distance(p, u_values, v_values, u_weights=nothing, v_weights=nothing)

Compute the `p`-th order distance between the empirical CDFs of two
one-dimensional distributions: `(sum(|U - V|^p * delta))^(1/p)`, where `U` and
`V` are the step-function CDFs evaluated on the merged, sorted sample points
and `delta` are the gaps between successive points.

# Arguments
- `p`: order of the distance (`1` and `2` take dedicated branches).
- `u_values`, `v_values`: samples of the two distributions.
- `u_weights`, `v_weights`: optional per-sample weights; `nothing` means every
  sample carries equal weight.

# Returns
The scalar distance between the two empirical distributions.

# Throws
Whatever `_validate_distribution!` throws for invalid inputs.
"""
function _cdf_distance(p, u_values, v_values, u_weights=nothing, v_weights=nothing)
    _validate_distribution!(u_values, u_weights)
    _validate_distribution!(v_values, v_weights)

    u_sorter = sortperm(u_values)
    v_sorter = sortperm(v_values)

    all_values = sort!(vcat(u_values, v_values))

    # Compute the differences between pairs of successive values of u and v.
    deltas = diff(all_values)

    # Get the respective positions of the values of u and v among the values of
    # both distributions. With side="right" these are counts (0..n) of samples
    # that are <= each merged value.
    u_cdf_indices = pysearchsorted(u_values[u_sorter], all_values[1:end-1], side="right")
    v_cdf_indices = pysearchsorted(v_values[v_sorter], all_values[1:end-1], side="right")

    # Calculate the CDFs of u and v using their weights, if specified.
    if u_weights === nothing
        u_cdf = u_cdf_indices / length(u_values)
    else
        u_sorted_cumweights = vcat([0], cumsum(u_weights[u_sorter]))
        # The counts are 0-based offsets into the zero-prepended cumulative
        # weights; Julia arrays are 1-based, so shift by one.
        u_cdf = u_sorted_cumweights[u_cdf_indices .+ 1] / u_sorted_cumweights[end]
    end

    if v_weights === nothing
        v_cdf = v_cdf_indices / length(v_values)
    else
        v_sorted_cumweights = vcat([0], cumsum(v_weights[v_sorter]))
        # Same 0-based -> 1-based shift as for u.
        v_cdf = v_sorted_cumweights[v_cdf_indices .+ 1] / v_sorted_cumweights[end]
    end

    # Compute the value of the integral based on the CDFs.
    if p == 1
        return sum(abs.(u_cdf - v_cdf) .* deltas)
    end
    if p == 2
        return sqrt(sum((u_cdf - v_cdf).^2 .* deltas))
    end
    return sum(abs.(u_cdf - v_cdf).^p .* deltas)^(1/p)
end

"""
    _validate_distribution!(vals, weights)

Check that `vals` (and `weights`, unless it is `nothing`) describe a valid
empirical distribution. Returns `nothing` when every check passes.

# Throws
- `ArgumentError` if `vals` is empty.
- `ArgumentError` if `weights` is given and its length differs from `vals`.
- `ArgumentError` if any weight is negative.
- `ArgumentError` if the sum of the weights is not positive and finite.
"""
function _validate_distribution!(vals, weights)
    # Validate the value array.
    isempty(vals) && throw(ArgumentError("Distribution can't be empty."))
    # Validate the weight array, if specified.
    if weights !== nothing
        if length(weights) != length(vals)
            throw(ArgumentError(
                "Value and weight array-likes for the same " *
                "empirical distribution must be of the same size."))
        end
        any(w -> w < 0, weights) &&
            throw(ArgumentError("All weights must be non-negative."))
        # The chained comparison also rejects NaN and Inf sums.
        if !(0 < sum(weights) < Inf)
            throw(ArgumentError(
                "Weight array-like sum must be positive and finite. " *
                "Set as `nothing` for an equal distribution of weight."))
        end
    end
    return nothing
end

"""
    wasserstein_distance(u_values, v_values, u_weights=nothing, v_weights=nothing)

Return the first-order CDF distance between the two empirical distributions,
obtained by calling `_cdf_distance` with `p = 1`. The optional weight
arguments are forwarded unchanged.
"""
wasserstein_distance(u_values, v_values, u_weights=nothing, v_weights=nothing) =
    _cdf_distance(1, u_values, v_values, u_weights, v_weights)

0 comments on commit 4b17ca7

Please sign in to comment.