Skip to content

Commit

Permalink
Orad/fix emulator bug repeat call (#281)
Browse files Browse the repository at this point in the history
* bugfix repeated calls, by skipping build

* typo in test
  • Loading branch information
odunbar authored Feb 13, 2024
1 parent b6a80c9 commit 6fcab82
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 0 deletions.
9 changes: 9 additions & 0 deletions src/GaussianProcess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,10 @@ function build_models!(

# Number of models (We are fitting one model per output dimension, as data is decorrelated)
models = gp.models
if length(gp.models) > 0 # check to see if gp already contains models
@warn "GaussianProcess already built. skipping..."
return
end
N_models = size(output_values, 1) #size(transformed_data)[1]


Expand Down Expand Up @@ -228,6 +232,11 @@ function build_models!(

# Number of models (We are fitting one model per output dimension, as data is decorrelated)
models = gp.models
if length(gp.models) > 0 # check to see if gp already contains models
@warn "GaussianProcess already built. skipping..."
return
end

N_models = size(output_values, 1) #size(transformed_data)[1]

if gp.kernel === nothing
Expand Down
4 changes: 4 additions & 0 deletions src/ScalarRandomFeature.jl
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,10 @@ function build_models!(


rfms = get_rfms(srfi)
if length(rfms) > 0
@warn "ScalarRandomFeatureInterface already built. skipping..."
return
end
fitted_features = get_fitted_features(srfi)
n_features = get_n_features(srfi)
batch_sizes = get_batch_sizes(srfi)
Expand Down
5 changes: 5 additions & 0 deletions src/VectorRandomFeature.jl
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,11 @@ function build_models!(
n_hp = calculate_n_hyperparameters(input_dim, output_dim, kernel_structure)

rfms = get_rfms(vrfi)
if length(rfms) > 0
@warn "VectorRandomFeatureInterface already built. skipping..."
return
end

fitted_features = get_fitted_features(vrfi)
n_features = get_n_features(vrfi)
batch_sizes = get_batch_sizes(vrfi)
Expand Down
37 changes: 37 additions & 0 deletions test/GaussianProcess/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ using CalibrateEmulateSample.DataContainers
@test gp1.prediction_type == pred_type
@test gp1.alg_reg_noise == 1e-4


em1 = Emulator(
gp1,
iopairs,
Expand All @@ -71,6 +72,24 @@ using CalibrateEmulateSample.DataContainers
retained_svd_frac = 1.0,
)

@test_logs (:warn,) Emulator(
gp1,
iopairs,
obs_noise_cov = nothing,
normalize_inputs = false,
standardize_outputs = false,
retained_svd_frac = 1.0,
) # check that gp1 does not get more models added under second call
Emulator(
gp1,
iopairs,
obs_noise_cov = nothing,
normalize_inputs = false,
standardize_outputs = false,
retained_svd_frac = 1.0,
)
@test length(gp1.models) == 1

Emulators.optimize_hyperparameters!(em1)

μ1, σ1² = Emulators.predict(em1, new_inputs)
Expand Down Expand Up @@ -120,6 +139,24 @@ using CalibrateEmulateSample.DataContainers
retained_svd_frac = 1.0,
)

@test_logs (:warn,) Emulator(
gp3,
iopairs,
obs_noise_cov = nothing,
normalize_inputs = false,
standardize_outputs = false,
retained_svd_frac = 1.0,
)
Emulator(
gp3,
iopairs,
obs_noise_cov = nothing,
normalize_inputs = false,
standardize_outputs = false,
retained_svd_frac = 1.0,
)
@test length(gp3.models) == 1 # check that gp3 does not get more models added under repeated calls

Emulators.optimize_hyperparameters!(em3)


Expand Down
11 changes: 11 additions & 0 deletions test/RandomFeature/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,18 @@ rng = Random.MersenneTwister(seed)

# build emulators
em_srfi = Emulator(srfi, iopairs, obs_noise_cov = obs_noise_cov)
n_srfi = length(get_rfms(srfi))
em_vrfi = Emulator(vrfi, iopairs, obs_noise_cov = obs_noise_cov)
n_vrfi = length(get_rfms(vrfi))

# test under repeats
@test_logs (:warn,) Emulator(srfi, iopairs, obs_noise_cov = obs_noise_cov)
Emulator(srfi, iopairs, obs_noise_cov = obs_noise_cov)
@test length(get_rfms(srfi)) == n_srfi
@test_logs (:warn,) Emulator(vrfi, iopairs, obs_noise_cov = obs_noise_cov)
Emulator(vrfi, iopairs, obs_noise_cov = obs_noise_cov)
@test length(get_rfms(vrfi)) == n_vrfi


# just see if it prints something
@test_logs (:info,) Emulators.optimize_hyperparameters!(em_srfi)
Expand Down

0 comments on commit 6fcab82

Please sign in to comment.