From 8a8614f792dbdbf8ca5131f1fdbc297157686faf Mon Sep 17 00:00:00 2001
From: Gabriele Bozzola
Date: Tue, 15 Oct 2024 13:10:04 -0700
Subject: [PATCH] Add support for checkpoints

This commit adds support for checkpointing and restarting a ClimaLand
simulation. The functionality is tested in a bucket experiment, where I
verify that saving a simulation to disk and restarting it leads to the
same state as running the simulation in one go.

In the process, I had to bump some packages (because I am using the
latest version of ClimaUtilities). I could not get ClimaLand to stay
compatible with 1.9, so I bumped the minimum version to 1.10. We should
try to keep 1.10 as the minimum version because it is LTS.
---
 .github/workflows/downgrade.yml     |   2 +-
 Project.toml                        |  10 +--
 docs/make.jl                        |   1 +
 docs/src/restarts.md                | 104 ++++++++++++++++++++++
 src/ClimaLand.jl                    |   1 +
 src/shared_utilities/checkpoints.jl | 128 ++++++++++++++++++++++++++++
 src/shared_utilities/models.jl      |  20 +++++
 src/shared_utilities/utils.jl       | 128 +++++++++++++++++++++++++++-
 test/Project.toml                   |   1 +
 test/runtests.jl                    |   3 +
 test/standalone/Bucket/restart.jl   | 102 ++++++++++++++++++++++
 11 files changed, 492 insertions(+), 8 deletions(-)
 create mode 100644 docs/src/restarts.md
 create mode 100644 src/shared_utilities/checkpoints.jl
 create mode 100644 test/standalone/Bucket/restart.jl

diff --git a/.github/workflows/downgrade.yml b/.github/workflows/downgrade.yml
index e71e0934bf..68f1dbe876 100644
--- a/.github/workflows/downgrade.yml
+++ b/.github/workflows/downgrade.yml
@@ -20,7 +20,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        version: ['1.9', '1.10', '1.11']
+        version: ['1.10', '1.11']
     steps:
       - uses: actions/checkout@v4
       - uses: julia-actions/setup-julia@latest
diff --git a/Project.toml b/Project.toml
index a6e5e1b992..c6ac08ac78 100644
--- a/Project.toml
+++ b/Project.toml
@@ -43,21 +43,21 @@ ClimaComms = "0.6"
 ClimaCore = "0.14.19"
 ClimaDiagnostics = "0.2.5"
 ClimaParams = "0.10.2"
-ClimaUtilities = "0.1.15"
+ClimaUtilities = "0.1.16"
 DataFrames = "1.4"
 Dates = "1"
 DocStringExtensions = "0.9"
-Flux = "0.14.0"
+Flux = "0.14.13"
 HTTP = "1.10"
 Insolation = "0.9.2"
-Interpolations = "0.15"
+Interpolations = "0.15.1"
 LazyArtifacts = "1"
 LinearAlgebra = "1"
 NCDatasets = "0.13.1, 0.14"
-SciMLBase = "2"
+SciMLBase = "2.34"
 StaticArrays = "1.5"
 StatsBase = "0.34"
 SurfaceFluxes = "0.11, 0.12"
 Thermodynamics = "0.12.4"
 cuDNN = "1"
-julia = "1.9"
+julia = "1.10"
diff --git a/docs/make.jl b/docs/make.jl
index 3a451bcd63..08d9695e4f 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -55,6 +55,7 @@ pages = Any[
     "Tutorials" => tutorials,
     "Standalone models" => standalone_models,
     "Diagnostics" => diagnostics,
+    "Restarts" => "restarts.md",
     "Contribution guide" => "Contributing.md",
     "Repository structure" => "folderstructure.md",
     "APIs" => apis,
diff --git a/docs/src/restarts.md b/docs/src/restarts.md
new file mode 100644
index 0000000000..56491c6e84
--- /dev/null
+++ b/docs/src/restarts.md
@@ -0,0 +1,104 @@
+## Restarting Simulations
+
+`ClimaLand` provides functionality to save and load simulation checkpoints,
+allowing you to restart simulations from a previous state. This is particularly
+useful for long-running simulations or if you want to experiment with different
+configurations starting from a specific point in the simulation.
+
+
+### Saving Checkpoints
+
+To save a simulation checkpoint, you can use the `ClimaLand.save_checkpoint`
+function. This function takes the current state `Y`, the simulation time `t`,
+and the output directory as arguments. Optionally, you can provide the
+`ClimaLand` model object `model`. This will store the hash of the model in the
+checkpoint file. You can use this information to ensure that you are restarting
+the simulation with the same model that was used to generate the checkpoint.
+
+```julia
+ClimaLand.save_checkpoint(Y, t, output_dir; model)
+```
+
+Most typically, this function is not called directly. Instead, it is called as a
+callback.
+
+In ClimaLand, you can automate the process of saving checkpoints using the
+`CheckpointCallback`. This callback allows you to specify the frequency at which
+checkpoints are saved and handles the saving process during the simulation.
+
+To use the `CheckpointCallback`, you need to create an instance of it and pass
+it to the `solve` function along with your other callbacks.
+
+Example:
+
+```julia
+
+# ... your ClimaLand simulation setup ...
+
+# Create a CheckpointCallback to save checkpoints every 6 hours
+checkpoint_cb = CheckpointCallback(Dates.Hour(6), output_dir, start_date, t_start; model, dt)
+
+# Add the callback to the callback set
+cb = SciMLBase.CallbackSet(checkpoint_cb, other_callbacks...)
+
+# Run the simulation with the callbacks
+sol = SciMLBase.solve(prob, ode_algo; dt = Δt, callback = cb)
+
+# ... your ClimaLand simulation analysis ...
+```
+
+In this example, the `CheckpointCallback` will save a checkpoint every 6 hours
+during the simulation. You can customize the `checkpoint_frequency` to control how
+often checkpoints are saved. You can also pass the `ClimaLand` model object `model`
+to store its hash in the checkpoint file. This information can be used later to
+ensure that you are restarting the simulation with the same model that was used
+to generate the checkpoint.
+
+If `dt` is passed, `CheckpointCallback` will also check that it is consistent
+with the checkpoint frequency.
+
+
+### Restarting from a Checkpoint
+
+To restart a simulation from a checkpoint, you can use the
+`ClimaLand.find_restart` function to locate the most recent checkpoint file in
+the output directory. Then, you can use the `ClimaLand.read_checkpoint` function
+to load the state vector and simulation time from the checkpoint file.
+
+```julia
+restart_file = ClimaLand.find_restart(output_dir)
+Y, t = ClimaLand.read_checkpoint(restart_file; model)
+```
+
+### Output Structure
+
+`ClimaLand` utilizes the `OutputPathGenerator` from `ClimaUtilities` to manage
+the output directory structure. By default, it uses the `ActiveLinkStyle`, which
+creates a sequence of numbered subfolders within the base output directory.
+
+For example, if your base output directory is `output`, the following structure
+will be created:
+```
+output/
+├── output_0000/
+│   └── ... checkpoint files ...
+├── output_0001/
+│   └── ... checkpoint files ...
+├── output_0002/
+│   └── ... checkpoint files ...
+└── output_active -> output_0002/
+```
+
+The `output_active` symbolic link always points to the most recent output
+subfolder, making it easy to access the latest simulation results.
+
+#### Checkpoint File Structure
+
+When using the `CheckpointCallback`, the checkpoints are saved as HDF5 files
+within the numbered output subfolders. The files are named using the following
+convention:
+```
+day$day.$sec.hdf5
+```
+For example, a checkpoint saved at day 10, 3600 seconds after midnight would be
+named `day10.3600.hdf5`.
diff --git a/src/ClimaLand.jl b/src/ClimaLand.jl
index 180d88defe..dc8f5dbbd2 100644
--- a/src/ClimaLand.jl
+++ b/src/ClimaLand.jl
@@ -18,6 +18,7 @@ import ClimaUtilities.SpaceVaryingInputs: SpaceVaryingInput
 import NCDatasets # Needed to load the ClimaUtilities.*VaryingInput
 using .Domains
 include("Artifacts.jl")
+include("shared_utilities/checkpoints.jl")
 include("shared_utilities/utils.jl")
 include("shared_utilities/models.jl")
 include("shared_utilities/drivers.jl")
diff --git a/src/shared_utilities/checkpoints.jl b/src/shared_utilities/checkpoints.jl
new file mode 100644
index 0000000000..98568f543e
--- /dev/null
+++ b/src/shared_utilities/checkpoints.jl
@@ -0,0 +1,128 @@
+import ClimaCore: InputOutput
+import ClimaUtilities
+
+"""
+    ClimaLand.find_restart(output_dir)
+
+Find the most recent restart file in the specified output directory.
+
+This function utilizes `ClimaUtilities.OutputPathGenerator.detect_restart_file`
+to locate the latest restart file within the output directory structure,
+assuming the `ActiveLinkStyle` is used for managing output folders.
+
+# Arguments
+- `output_dir`: The base output directory where the simulation results are stored.
+
+# Returns
+- The path to the most recent restart file found, or `nothing` if no restart
+  file is found.
+"""
+function find_restart(output_dir)
+    return ClimaUtilities.OutputPathGenerator.detect_restart_file(
+        ClimaUtilities.OutputPathGenerator.ActiveLinkStyle(),
+        output_dir,
+    )
+end
+
+"""
+    _context_from_Y(Y)
+
+Try extracting the context from the FieldVector Y.
+
+Typically Y has a structure like:
+```
+Y
+  .bucket
+    .T
+    .W
+    .Ws
+```
+`_context_from_Y` tries obtaining the context from the first Field in the
+hierarchy.
+"""
+function _context_from_Y(Y)
+    a_model_type = getproperty(Y, first(propertynames(Y)))
+    a_field = getproperty(a_model_type, first(propertynames(a_model_type)))
+    return ClimaComms.context(a_field)
+end
+
+"""
+    ClimaLand.save_checkpoint(Y, t, output_dir; model = nothing, comms_ctx = ClimaComms.context(Y))
+
+Save a simulation checkpoint to an HDF5 file.
+
+This function saves the current state of the simulation, including the state
+vector `Y` and the current simulation time `t`, to an HDF5 file within the
+specified output directory.
+
+# Arguments
+- `Y`: The state of the simulation.
+- `t`: The current simulation time.
+- `output_dir`: The directory where the checkpoint file will be saved.
+- `model` (Optional): The ClimaLand model object. If provided the hash of the model
+  will be stored in the checkpoint file. Defaults to `nothing`. This is used
+  to check for consistency.
+- `comms_ctx` (Optional): The ClimaComms context. This is used for distributed I/O
+  operations. Defaults to the context extracted from the state vector `Y` or the `model`.
+"""
+function save_checkpoint(
+    Y,
+    t,
+    output_dir;
+    model = nothing,
+    comms_ctx = isnothing(model) ? _context_from_Y(Y) :
+                ClimaComms.context(model),
+)
+    day = floor(Int, t / (60 * 60 * 24))
+    sec = floor(Int, t % (60 * 60 * 24))
+    output_file = joinpath(output_dir, "day$day.$sec.hdf5")
+    hdfwriter = InputOutput.HDF5Writer(output_file, comms_ctx)
+    # If model was passed, add its hash, otherwise add nothing
+    hash_model = isnothing(model) ? "nothing" : hash(model)
+    InputOutput.write_attributes!(
+        hdfwriter,
+        "/",
+        Dict("time" => t, "land_model_hash" => hash_model),
+    )
+    InputOutput.write!(hdfwriter, Y, "Y")
+    Base.close(hdfwriter)
+    return nothing
+end
+
+"""
+    ClimaLand.read_checkpoint(file_path; model = nothing, context = ClimaComms.context())
+
+Read a simulation checkpoint from an HDF5 file.
+
+This function loads the simulation state from a previously saved checkpoint file.
+
+# Arguments
+- `file_path`: The path to the HDF5 checkpoint file.
+- `model` (Optional): The ClimaLand model object. If provided the hash of the model
+  stored in the checkpoint file will be compared with the hash of the provided
+  model and a warning will be issued if they don't match. Defaults to `nothing`.
+- `context` (Optional): The ClimaComms context. This is used for parallel I/O
+  operations. Defaults to the context of `model`, if given, else `ClimaComms.context()`.
+
+# Returns
+- `Y`: The state vector loaded from the checkpoint file.
+- `t`: The simulation time loaded from the checkpoint file.
+"""
+function read_checkpoint(
+    file_path;
+    model = nothing,
+    context = isnothing(model) ? ClimaComms.context() :
+              ClimaComms.context(model),
+)
+    hdfreader = InputOutput.HDF5Reader(file_path, context)
+    Y = InputOutput.read_field(hdfreader, "Y")
+    attributes = InputOutput.read_attributes(hdfreader, "/")
+    if !isnothing(model)
+        if hash(model) != attributes["land_model_hash"]
+            @warn "Restart file $(file_path) was constructed with a different land model"
+        end
+    end
+    t = attributes["time"]
+    Base.close(hdfreader)
+    return Y, t
+end
diff --git a/src/shared_utilities/models.jl b/src/shared_utilities/models.jl
index 1c8584f455..936dbeea86 100644
--- a/src/shared_utilities/models.jl
+++ b/src/shared_utilities/models.jl
@@ -447,6 +447,26 @@ function initialize(model::AbstractModel{FT}) where {FT}
     return Y, p, coords
 end

+
+
+"""
+    initialize_from_checkpoint(restart_file; model::AbstractModel)
+
+Reads the prognostic state and time from `restart_file`, creates the auxiliary
+state structures with unset values, and constructs the coordinates for the
+`model` domain. We may need to consider this default more as we add diverse
+components and `Simulations`.
+
+TODO: Combine this function with initialize. We don't really need two.
+"""
+function initialize_from_checkpoint(restart_file; model)
+    Y, t_checkpoint = read_checkpoint(restart_file; model)
+    coords = Domains.coordinates(model)
+    p = initialize_auxiliary(model, coords)
+    p = add_drivers_to_cache(p, model, coords)
+    return Y, p, coords, t_checkpoint
+end
+
+
 function ClimaComms.context(model::AbstractModel)
     if :domain ∈ propertynames(model)
         return ClimaComms.context(model.domain)
diff --git a/src/shared_utilities/utils.jl b/src/shared_utilities/utils.jl
index fa69ee81c0..abbaa27af3 100644
--- a/src/shared_utilities/utils.jl
+++ b/src/shared_utilities/utils.jl
@@ -1,5 +1,7 @@
 import ClimaCore
 import SciMLBase
+import ClimaDiagnostics.Schedules: EveryCalendarDtSchedule
+import Dates

 export FTfromY

@@ -265,6 +267,75 @@ function DriverUpdateCallback(updateat::Vector{FT}, updatefunc) where {FT}
     )
 end

+"""
+    CheckpointCallback(checkpoint_frequency::Union{AbstractFloat, Dates.Period},
+                       output_dir, start_date, t_start; model, dt)
+
+Constructs a DiscreteCallback which saves the state to disk with the
+`save_checkpoint` function.
+
+# Arguments
+- `checkpoint_frequency`: The frequency at which checkpoints are saved. Can be
+  specified as a float (in seconds) or a `Dates.Period`.
+- `output_dir`: The directory where the checkpoint files will be saved.
+- `start_date`: The start date of the simulation.
+- `t_start`: The starting time of the simulation (in seconds).
+- `model`: The ClimaLand model object.
+- `dt`: The timestep of the model (optional), used to check for consistency.
+
+The callback uses `ClimaDiagnostics.EveryCalendarDtSchedule` to determine when
+to save checkpoints based on the `checkpoint_frequency`. The schedule is
+initialized with the `start_date` and `t_start` to ensure that the first
+checkpoint is saved at the correct time.
+
+The `save_checkpoint` function is called with the current state vector `u`, the
+current time `t`, and the `output_dir` to save the checkpoint to disk.
+"""
+function CheckpointCallback(
+    checkpoint_frequency::Union{AbstractFloat, Dates.Period},
+    output_dir,
+    start_date,
+    t_start;
+    model,
+    dt = nothing,
+)
+    # TODO: Move to a more general callback system. For the time being, we use
+    # the ClimaDiagnostics one because it is flexible and it supports calendar
+    # dates.
+
+    if checkpoint_frequency isa AbstractFloat
+        # Assume it is in seconds, but go through Millisecond to support
+        # fractional seconds
+        checkpoint_frequency_period =
+            Dates.Millisecond(1000checkpoint_frequency)
+    else
+        checkpoint_frequency_period = checkpoint_frequency
+    end
+
+    schedule = EveryCalendarDtSchedule(
+        checkpoint_frequency_period;
+        reference_date = start_date,
+        date_last = start_date + Dates.Millisecond(1000t_start),
+    )
+
+    if !isnothing(dt)
+        dt_period = Dates.Millisecond(1000dt)
+        if !isdivisible(checkpoint_frequency_period, dt_period)
+            @warn "Checkpoint frequency ($(checkpoint_frequency_period)) is not an integer multiple of dt $(dt_period)"
+        end
+    end
+
+    cond = let schedule = schedule
+        (u, t, integrator) -> schedule(integrator)
+    end
+    affect! = let output_dir = output_dir, model = model
+        (integrator) ->
+            save_checkpoint(integrator.u, integrator.t, output_dir; model)
+    end
+
+    SciMLBase.DiscreteCallback(cond, affect!)
+end
+
 """
     driver_initialize(cb, u, t, integrator)

@@ -278,8 +349,9 @@ end
     update_condition(updateat)

 This function returns a function with the type signature expected by
-`SciMLBase.DiscreteCallback`, and determines whether `affect!` gets
-called in the callback. This implementation simply checks if the current time of the simulation is within the (inclusive) bounds of `updateat`.
+`SciMLBase.DiscreteCallback`, and determines whether `affect!` gets called in
+the callback. This implementation simply checks if the current time of the
+simulation is within the (inclusive) bounds of `updateat`.
 """
 update_condition(updateat) =
     (_, t, _) -> t >= minimum(updateat) && t <= maximum(updateat)
@@ -368,3 +440,55 @@ condition(saveat) = (_, t, _) -> t in saveat
 function FTfromY(Y::ClimaCore.Fields.FieldVector)
     return eltype(Y)
 end
+
+
+"""
+    isdivisible(dt_large::Dates.Period, dt_small::Dates.Period)
+
+Check if two periods are evenly divisible, i.e., if the larger period can be
+expressed as an integer multiple of the smaller period.
+
+This takes into account periods that do not have a fixed size, e.g.,
+one month is a variable number of days.
+
+# Examples
+```
+julia> isdivisible(Dates.Year(1), Dates.Month(1))
+true
+
+julia> isdivisible(Dates.Month(1), Dates.Day(1))
+true
+
+julia> isdivisible(Dates.Month(1), Dates.Week(1))
+false
+```
+
+## Notes
+
+Not all the combinations are fully implemented. If something is missing, please
+consider adding it.
+"""
+function isdivisible(dt_large::Dates.Period, dt_small::Dates.Period)
+    @warn "The combination $(typeof(dt_large)) and $(typeof(dt_small)) was not covered. Please add a method to handle this case."
+    return false
+end
+
+# For FixedPeriod and OtherPeriod, it is easy, we can directly divide the two
+# (as long as both periods are of the same kind)
+function isdivisible(dt_large::Dates.FixedPeriod, dt_small::Dates.FixedPeriod)
+    return isinteger(dt_large / dt_small)
+end
+
+function isdivisible(dt_large::Dates.OtherPeriod, dt_small::Dates.OtherPeriod)
+    return isinteger(dt_large / dt_small)
+end
+
+function isdivisible(
+    dt_large::Union{Dates.Month, Dates.Year},
+    dt_small::Dates.FixedPeriod,
+)
+    # The only case where periods are commensurate for Month/Year is when we
+    # have a Day or an integer divisor of a day. (Note that 365 and 366 don't
+    # have any common divisor)
+    return isinteger(Dates.Day(1) / dt_small)
+end
diff --git a/test/Project.toml b/test/Project.toml
index 1c5efc4c30..4a79774125 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -23,6 +23,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 NCDatasets = "85f8d34a-cbdd-5861-8df4-14fed0d494ab"
 NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
 SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
+SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
 StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
diff --git a/test/runtests.jl b/test/runtests.jl
index 01f5737c6b..46f331e71c 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -37,6 +37,9 @@ end
 @safetestset "Bucket soil tests" begin
     include("standalone/Bucket/soil_bucket_tests.jl")
 end
+@safetestset "Restart tests" begin
+    include("standalone/Bucket/restart.jl")
+end

 # Standalone Snow model tests
 @safetestset "Snow parameterization tests" begin
diff --git a/test/standalone/Bucket/restart.jl b/test/standalone/Bucket/restart.jl
new file mode 100644
index 0000000000..73ccc59a29
--- /dev/null
+++ b/test/standalone/Bucket/restart.jl
@@ -0,0 +1,102 @@
+using Test
+import ClimaComms
+ClimaComms.@import_required_backends
+import ClimaLand
+import ClimaParams
+import SciMLBase
+import ClimaTimeSteppers
+import Dates
+import ClimaUtilities
+
+# First, run a simulation for 3 hours (three 1-hour steps)
+FT = Float32
+t0 = 0.0
+Δt = 3600.0
+tf = 3Δt
+start_date = Dates.DateTime(2005)
+root_path = "bucket_restart"
+output_dir = ClimaUtilities.OutputPathGenerator.generate_output_path(root_path)
+
+bucket_domain = ClimaLand.SphericalShell(;
+    radius = FT(100),
+    depth = FT(3.5),
+    nelements = (1, 10),
+    npolynomial = 1,
+)
+
+bucket_atmos, bucket_rad = ClimaLand.prescribed_analytic_forcing(FT)
+τc = FT(1.0)
+α_bareground_func = (coordinate_point) -> 0.2
+α_snow = FT(0.8)
+z_0m = FT(1e-2)
+z_0b = FT(1e-3)
+albedo = ClimaLand.Bucket.PrescribedBaregroundAlbedo{FT}(
+    α_snow,
+    α_bareground_func,
+    bucket_domain.space.surface,
+)
+bucket_parameters =
+    ClimaLand.Bucket.BucketModelParameters(FT; albedo, z_0m, z_0b, τc)
+
+model = ClimaLand.Bucket.BucketModel(
+    parameters = bucket_parameters,
+    domain = bucket_domain,
+    atmosphere = bucket_atmos,
+    radiation = bucket_rad,
+)
+
+Y, p, coords = ClimaLand.initialize(model)
+Y.bucket.T .= 280.0
+Y.bucket.W .= 0.05
+Y.bucket.Ws .= 0.0
+Y.bucket.σS .= 0.08
+
+exp_tendency! = ClimaLand.make_exp_tendency(model)
+set_initial_cache! = ClimaLand.make_set_initial_cache(model)
+set_initial_cache!(p, Y, t0)
+
+prob = SciMLBase.ODEProblem(
+    ClimaTimeSteppers.ClimaODEFunction((T_exp!) = exp_tendency!),
+    Y,
+    (t0, tf),
+    p,
+)
+updateat = collect(t0:Δt:tf)
+checkpoint_frequency = 2Δt
+drivers = ClimaLand.get_drivers(model)
+updatefunc = ClimaLand.make_update_drivers(drivers)
+driver_cb = ClimaLand.DriverUpdateCallback(updateat, updatefunc)
+checkpoint_cb = ClimaLand.CheckpointCallback(
+    checkpoint_frequency,
+    output_dir,
+    start_date,
+    t0;
+    model,
+)
+cb = SciMLBase.CallbackSet(driver_cb, checkpoint_cb)
+
+timestepper = ClimaTimeSteppers.RK4()
+ode_algo = ClimaTimeSteppers.ExplicitAlgorithm(timestepper)
+
+sol = SciMLBase.solve(prob, ode_algo; dt = Δt, callback = cb)
+
+# Now, let's restart from the checkpoint (we have to pass the root_path, not the output_dir)
+restart_file = ClimaLand.find_restart(root_path)
+Y_restart, p_restart, _, t_restart =
+    ClimaLand.initialize_from_checkpoint(restart_file; model)
+set_initial_cache!(p_restart, Y_restart, t_restart)
+prob_restart = SciMLBase.ODEProblem(
+    ClimaTimeSteppers.ClimaODEFunction((T_exp!) = exp_tendency!),
+    Y_restart,
+    (t_restart, tf),
+    p_restart,
+)
+updateat_restarted = collect(t_restart:Δt:tf)
+driver_cb_restarted =
+    ClimaLand.DriverUpdateCallback(updateat_restarted, updatefunc)
+cb_restarted = SciMLBase.CallbackSet(driver_cb_restarted)
+sol_restarted =
+    SciMLBase.solve(prob_restart, ode_algo; dt = Δt, callback = cb_restarted)
+for p in propertynames(Y.bucket)
+    @test getproperty(Y.bucket, p) == getproperty(Y_restart.bucket, p)
+end