From 6fcab82a690d71ef7c6b3810901da80ba1cfcb43 Mon Sep 17 00:00:00 2001 From: Oliver Dunbar <47412152+odunbar@users.noreply.github.com> Date: Tue, 13 Feb 2024 15:57:33 -0800 Subject: [PATCH] Orad/fix emulator bug repeat call (#281) * bugfix repeated calls, by skipping build * typo in test --- src/GaussianProcess.jl | 9 ++++++++ src/ScalarRandomFeature.jl | 4 ++++ src/VectorRandomFeature.jl | 5 +++++ test/GaussianProcess/runtests.jl | 37 ++++++++++++++++++++++++++++++++ test/RandomFeature/runtests.jl | 11 ++++++++++ 5 files changed, 66 insertions(+) diff --git a/src/GaussianProcess.jl b/src/GaussianProcess.jl index a86490063..7727cb58d 100644 --- a/src/GaussianProcess.jl +++ b/src/GaussianProcess.jl @@ -113,6 +113,10 @@ function build_models!( # Number of models (We are fitting one model per output dimension, as data is decorrelated) models = gp.models + if length(gp.models) > 0 # check to see if gp already contains models + @warn "GaussianProcess already built. skipping..." + return + end N_models = size(output_values, 1) #size(transformed_data)[1] @@ -228,6 +232,11 @@ function build_models!( # Number of models (We are fitting one model per output dimension, as data is decorrelated) models = gp.models + if length(gp.models) > 0 # check to see if gp already contains models + @warn "GaussianProcess already built. skipping..." + return + end + N_models = size(output_values, 1) #size(transformed_data)[1] if gp.kernel === nothing diff --git a/src/ScalarRandomFeature.jl b/src/ScalarRandomFeature.jl index 05a9c9987..2ea40a18b 100644 --- a/src/ScalarRandomFeature.jl +++ b/src/ScalarRandomFeature.jl @@ -305,6 +305,10 @@ function build_models!( rfms = get_rfms(srfi) + if length(rfms) > 0 + @warn "ScalarRandomFeatureInterface already built. skipping..." + return + end fitted_features = get_fitted_features(srfi) n_features = get_n_features(srfi) batch_sizes = get_batch_sizes(srfi) diff --git a/src/VectorRandomFeature.jl b/src/VectorRandomFeature.jl index 96a0ef231..7f4b40885 100644 --- a/src/VectorRandomFeature.jl +++ b/src/VectorRandomFeature.jl @@ -354,6 +354,11 @@ function build_models!( n_hp = calculate_n_hyperparameters(input_dim, output_dim, kernel_structure) rfms = get_rfms(vrfi) + if length(rfms) > 0 + @warn "VectorRandomFeatureInterface already built. skipping..." + return + end + fitted_features = get_fitted_features(vrfi) n_features = get_n_features(vrfi) batch_sizes = get_batch_sizes(vrfi) diff --git a/test/GaussianProcess/runtests.jl b/test/GaussianProcess/runtests.jl index 5935f834a..5c0001088 100644 --- a/test/GaussianProcess/runtests.jl +++ b/test/GaussianProcess/runtests.jl @@ -62,6 +62,7 @@ using CalibrateEmulateSample.DataContainers @test gp1.prediction_type == pred_type @test gp1.alg_reg_noise == 1e-4 + em1 = Emulator( gp1, iopairs, @@ -71,6 +72,24 @@ using CalibrateEmulateSample.DataContainers retained_svd_frac = 1.0, ) + @test_logs (:warn,) Emulator( + gp1, + iopairs, + obs_noise_cov = nothing, + normalize_inputs = false, + standardize_outputs = false, + retained_svd_frac = 1.0, + ) # check that gp1 does not get more models added under second call + Emulator( + gp1, + iopairs, + obs_noise_cov = nothing, + normalize_inputs = false, + standardize_outputs = false, + retained_svd_frac = 1.0, + ) + @test length(gp1.models) == 1 + Emulators.optimize_hyperparameters!(em1) μ1, σ1² = Emulators.predict(em1, new_inputs) @@ -120,6 +139,24 @@ using CalibrateEmulateSample.DataContainers retained_svd_frac = 1.0, ) + @test_logs (:warn,) Emulator( + gp3, + iopairs, + obs_noise_cov = nothing, + normalize_inputs = false, + standardize_outputs = false, + retained_svd_frac = 1.0, + ) + Emulator( + gp3, + iopairs, + obs_noise_cov = nothing, + normalize_inputs = false, + standardize_outputs = false, + retained_svd_frac = 1.0, + ) + @test length(gp3.models) == 1 # check that gp3 does not get more models added under repeated calls + Emulators.optimize_hyperparameters!(em3) diff --git a/test/RandomFeature/runtests.jl b/test/RandomFeature/runtests.jl index 6867040ee..75110bfa3 100644 --- a/test/RandomFeature/runtests.jl +++ b/test/RandomFeature/runtests.jl @@ -280,7 +280,18 @@ rng = Random.MersenneTwister(seed) # build emulators em_srfi = Emulator(srfi, iopairs, obs_noise_cov = obs_noise_cov) + n_srfi = length(get_rfms(srfi)) em_vrfi = Emulator(vrfi, iopairs, obs_noise_cov = obs_noise_cov) + n_vrfi = length(get_rfms(vrfi)) + + # test under repeats + @test_logs (:warn,) Emulator(srfi, iopairs, obs_noise_cov = obs_noise_cov) + Emulator(srfi, iopairs, obs_noise_cov = obs_noise_cov) + @test length(get_rfms(srfi)) == n_srfi + @test_logs (:warn,) Emulator(vrfi, iopairs, obs_noise_cov = obs_noise_cov) + Emulator(vrfi, iopairs, obs_noise_cov = obs_noise_cov) + @test length(get_rfms(vrfi)) == n_vrfi + # just see if it prints something @test_logs (:info,) Emulators.optimize_hyperparameters!(em_srfi)