Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds in basic control net infrastructure and some tests #174

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/CliMAgen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,10 @@ export struct2dict, dict2nt
export VarianceExplodingSDE
export drift, diffusion, marginal_prob, score
export vanilla_score_matching_loss, score_matching_loss
export NoiseConditionalScoreNetwork, DenoisingDiffusionNetwork, ResnetBlockDDN, ResnetBlockNCSN, AttentionBlock, CircularConv
export NoiseConditionalScoreNetwork, DenoisingDiffusionNetwork, ResnetBlockDDN, ResnetBlockNCSN, AttentionBlock, CircularConv, ControlledNoiseConditionalScoreNetwork
export WarmupSchedule, ExponentialMovingAverage
export train!, load_model_and_optimizer, save_model_and_optimizer
export setup_sampler, Euler_Maruyama_sampler, Euler_Maruyama_ld_sampler, predictor_corrector_sampler
export MeanSpatialScaling, StandardScaling, apply_preprocessing, invert_preprocessing

end
259 changes: 259 additions & 0 deletions src/networks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -994,3 +994,262 @@ Carries on the spatial convolution respecting periodicity at the boundaries.
function (layer::CircularConv)(x)
layer.conv(NNlib.pad_circular(x, layer.pad))
end

"""
CliMAgen.ControlNet

The struct containing the parameters and layers of the ControlNet architecture.
"""
struct ControlNet{N}
net::N
trainable::Bool # whether the network is trainable
end

"""
ControlNet(net::N; trainable::Bool=false)

Creates a ControlNet with the given neural network `net` and whether it is trainable.
"""
function ControlNet(net::N; trainable::Bool=false) where N
return ControlNet{N}(net, trainable)
end

@functor ControlNet

"""
(net::ControlNet)(x)

Evaluates the neural network of the ControlNet model on `x`.
"""
function (c::ControlNet)(x)
return c.net(x)
end

"""
Flux.params(::ControlNet)
Returns the trainable parameters of the ControlNet).
"""
Flux.params(c::ControlNet) = c.trainable ? Flux.params(c.net) : nothing

"""
CliMAgen.ControlledNoiseConditionalScoreNetwork

The struct containing the parameters and layers
of the Noise Conditional Score Network architecture,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

of the Controlled Noise Conditional Score Network

with the option to include a mean-bypass layer.

# References
Unet: https://arxiv.org/abs/1505.04597
"""
struct ControlledNoiseConditionalScoreNetwork{N}
"The layers of the network"
layers::NamedTuple
"A control network to condition the output of the U-net"
control_net::N
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like the output of the control net needs to be a vector of length embed_dim (same as for the time embedding). Is this correct? Since the control net is made exterior to the NCSN, where embed_dim is defined, we should maybe just add some documentation stating this, or maybe an @Assert statement somewhere, if we can?

"A boolean indicating if a mean-bypass layer should be used"
mean_bypass::Bool
"A boolean indicating if the output of the mean-bypass layer should be scaled"
scale_mean_bypass::Bool
"A boolean indicating if the input is demeaned before being passed to the U-net"
shift_input::Bool
"A boolean indicating if the output of the Unet is demeaned"
shift_output::Bool
"A boolean indicating if a groupnorm should be used in the mean-bypass layer"
gnorm::Bool
end

function ControlledNoiseConditionalScoreNetwork(; control_net,
mean_bypass=false,
scale_mean_bypass=false,
shift_input=false,
shift_output=false,
gnorm=false,
nspatial=2,
dropout_p=0.0f0,
num_residual=8,
noised_channels=1,
channels=[32, 64, 128, 256],
embed_dim=256,
scale=30.0f0,
periodic=false,
proj_kernelsize=3,
outer_kernelsize=3,
middle_kernelsize=3,
inner_kernelsize=3)
if scale_mean_bypass & !mean_bypass
@error("Attempting to scale the mean bypass term without adding in a mean bypass connection.")
end
if gnorm & !mean_bypass
@error("Attempting to gnorm without adding in a mean bypass connection.")
end
inchannels = noised_channels
outchannels = noised_channels

# Mean processing as indicated by boolean mean_bypass
if mean_bypass
if gnorm
mean_bypass_layers = (
mean_skip_1 = Conv((1, 1), inchannels => embed_dim),
mean_skip_2 = Conv((1, 1), embed_dim => embed_dim),
mean_skip_3 = Conv((1, 1), embed_dim => outchannels),
mean_gnorm_1 = GroupNorm(embed_dim, 32, swish),
mean_gnorm_2 = GroupNorm(embed_dim, 32, swish),
mean_dense_1 = Dense(embed_dim, embed_dim),
mean_dense_2 = Dense(embed_dim, embed_dim),
)
else
mean_bypass_layers = (
mean_skip_1 = Conv((1, 1), inchannels => embed_dim),
mean_skip_2 = Conv((1, 1), embed_dim => embed_dim),
mean_skip_3 = Conv((1, 1), embed_dim => outchannels),
mean_dense_1 = Dense(embed_dim, embed_dim),
mean_dense_2 = Dense(embed_dim, embed_dim),
)
end
else
mean_bypass_layers = ()
end

# Lifting/Projection layers depend on periodicity of data
if periodic
conv1 = CircularConv(3, nspatial, inchannels => channels[1] ; stride=1)
tconv1 = CircularConv(proj_kernelsize, nspatial, channels[1] + channels[1] => outchannels; stride=1)
else
conv1=Conv((3, 3), inchannels => channels[1], stride=1, pad=SamePad())
tconv1=Conv((proj_kernelsize, proj_kernelsize), channels[1] + channels[1] => outchannels, stride=1, pad=SamePad())
end

layers = (gaussfourierproj=GaussianFourierProjection(embed_dim, scale),
linear=Dense(embed_dim, embed_dim, swish),

# Lifting
conv1=conv1,
dense1=Dense(embed_dim, channels[1]),
control_dense1=Dense(embed_dim, channels[1]),
gnorm1=GroupNorm(channels[1], 4, swish),

# Encoding
conv2=Downsampling(channels[1] => channels[2], nspatial, kernel_size=3, periodic=periodic),
dense2=Dense(embed_dim, channels[2]),
control_dense2=Dense(embed_dim, channels[2]),
gnorm2=GroupNorm(channels[2], 32, swish),

conv3=Downsampling(channels[2] => channels[3], nspatial, kernel_size=3, periodic=periodic),
dense3=Dense(embed_dim, channels[3]),
control_dense3=Dense(embed_dim, channels[3]),
gnorm3=GroupNorm(channels[3], 32, swish),

conv4=Downsampling(channels[3] => channels[4], nspatial, kernel_size=3, periodic=periodic),
dense4=Dense(embed_dim, channels[4]),
control_dense4=Dense(embed_dim, channels[4]),

# Residual Blocks
resnet_blocks =
[ResnetBlockNCSN(channels[end], nspatial, embed_dim; p = dropout_p, periodic=periodic) for _ in range(1, length=num_residual)],

# Decoding
gnorm4=GroupNorm(channels[4], 32, swish),
tconv4=Upsampling(channels[4] => channels[3], nspatial, kernel_size=inner_kernelsize, periodic=periodic),
denset4=Dense(embed_dim, channels[3]),
control_denset4=Dense(embed_dim, channels[3]),
tgnorm4=GroupNorm(channels[3], 32, swish),

tconv3=Upsampling(channels[3]+channels[3] => channels[2], nspatial, kernel_size=middle_kernelsize, periodic=periodic),
denset3=Dense(embed_dim, channels[2]),
control_denset3=Dense(embed_dim, channels[2]),
tgnorm3=GroupNorm(channels[2], 32, swish),

tconv2=Upsampling(channels[2]+channels[2] => channels[1], nspatial, kernel_size=outer_kernelsize, periodic=periodic),
denset2=Dense(embed_dim, channels[1]),
control_denset2=Dense(embed_dim, channels[1]),
tgnorm2=GroupNorm(channels[1], 32, swish),

# Projection
tconv1=tconv1,
mean_bypass_layers...
)

return ControlledNoiseConditionalScoreNetwork(layers, control_net, mean_bypass, scale_mean_bypass, shift_input, shift_output, gnorm)
end

@functor ControlledNoiseConditionalScoreNetwork

"""
(net::ControlledNoiseConditionalScoreNetwork)(x, c, t)

Evaluates the neural network of the NoiseConditionalScoreNetwork
model on (x,c,t), where `x` is the tensor of noised input,
`c` is the tensor of contextual input, and `t` is a tensor of times.
"""
function (net::ControlledNoiseConditionalScoreNetwork)(x, c, t)
# Get size of spatial dimensions
nspatial = ndims(x) - 2

# Embeddings
embed = net.layers.gaussfourierproj(t)
embed = net.layers.linear(embed)
control_embed = net.control_net(c)

# Encoder
if net.shift_input
h1 = x .- mean(x, dims=(1:nspatial)) # remove mean of noised variables before input
else
h1 = x
end
h1 = net.layers.conv1(h1)
h1 = h1 .+ expand_dims(net.layers.dense1(embed) .+ net.layers.control_dense1(control_embed), nspatial)
h1 = net.layers.gnorm1(h1)
h2 = net.layers.conv2(h1)
h2 = h2 .+ expand_dims(net.layers.dense2(embed) .+ net.layers.control_dense2(control_embed), nspatial)
h2 = net.layers.gnorm2(h2)
h3 = net.layers.conv3(h2)
h3 = h3 .+ expand_dims(net.layers.dense3(embed) .+ net.layers.control_dense3(control_embed), nspatial)
h3 = net.layers.gnorm3(h3)
h4 = net.layers.conv4(h3)
h4 = h4 .+ expand_dims(net.layers.dense4(embed) .+ net.layers.control_dense4(control_embed), nspatial)

# middle
h = h4
for block in net.layers.resnet_blocks
h = block(h, embed .+ control_embed) # add in control embedding, can perhaps be done better.
bischtob marked this conversation as resolved.
Show resolved Hide resolved
end

# Decoder
h = net.layers.gnorm4(h)
h = net.layers.tconv4(h)
h = h .+ expand_dims(net.layers.denset4(embed) .+ net.layers.control_denset4(control_embed), nspatial)
h = net.layers.tgnorm4(h)
h = net.layers.tconv3(cat(h, h3; dims=nspatial+1))
h = h .+ expand_dims(net.layers.denset3(embed) .+ net.layers.control_denset3(control_embed), nspatial)
h = net.layers.tgnorm3(h)
h = net.layers.tconv2(cat(h, h2, dims=nspatial+1))
h = h .+ expand_dims(net.layers.denset2(embed) .+ net.layers.control_denset2(control_embed), nspatial)
h = net.layers.tgnorm2(h)
h = net.layers.tconv1(cat(h, h1, dims=nspatial+1))
if net.shift_output
h = h .- mean(h, dims=(1:nspatial)) # remove mean after output
end

# Mean processing of noised variable channels
if net.mean_bypass
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should the embedded control/condition also be used by the mean bypass layer?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(we do use the context in the mean bypass in our other version, though it's only the mean of the context that is used)

hm = net.layers.mean_skip_1(mean(x, dims=(1:nspatial)))
hm = hm .+ expand_dims(net.layers.mean_dense_1(embed), nspatial)
if net.gnorm
hm = net.layers.mean_gnorm_1(hm)
end
hm = net.layers.mean_skip_2(hm)
hm = hm .+ expand_dims(net.layers.mean_dense_2(embed), nspatial)
if net.gnorm
hm = net.layers.mean_gnorm_2(hm)
end
hm = net.layers.mean_skip_3(hm)
if net.scale_mean_bypass
scale = convert(eltype(x), sqrt(prod(size(x)[1:nspatial])))
hm = hm ./ scale
end
# Add back in noised channel mean to noised channel spatial variatons
return h .+ hm
else
return h
end
end
31 changes: 31 additions & 0 deletions test/test_networks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,4 +193,35 @@ end
sum(net(x, c, t) .^ 2)
end
@test loss isa Real
end

@testset "ControlNet" begin
# constructor
net = Dense(10, 5)
controlnet = CliMAgen.ControlNet(net, trainable=true)
@test controlnet.net == net
@test controlnet.trainable == true

x = randn(10)
@test controlnet(x) == net(x)
end

@testset "ControlledNoiseConditionalScoreNetwork" begin
# with controlnet
control_net = CliMAgen.ControlNet(Dense(11, 256), trainable=true)
net = CliMAgen.ControlledNoiseConditionalScoreNetwork(control_net=control_net, noised_channels=2)
ps = Flux.params(net)
k = 5
x = rand(Float32, 2^k, 2^k, 2, 11)
c = rand(Float32, 11)
t = rand(Float32)

# forward pass
@test net(x, c, t) |> size == (2^k, 2^k, 2, 11)

# backward pass of dummy loss
loss, grad = Flux.withgradient(ps) do
sum(net(x, c, t) .^ 2)
end
@test loss isa Real
end