-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adds in basic control net infrastructure and some tests #174
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -994,3 +994,262 @@ Carries on the spatial convolution respecting periodicity at the boundaries. | |
function (layer::CircularConv)(x)
    # Pad the input circularly by `layer.pad` first, then run the wrapped convolution
    # on the padded array.
    padded = NNlib.pad_circular(x, layer.pad)
    return layer.conv(padded)
end
|
||
"""
    CliMAgen.ControlNet

Container for the parameters and layers of the ControlNet architecture: a wrapped
network together with a flag stating whether that network is trainable.

# Fields
- `net`: the wrapped network.
- `trainable`: whether the network is trainable.
"""
struct ControlNet{N}
    "The wrapped network"
    net::N
    "Whether the network is trainable"
    trainable::Bool
end
|
||
"""
    ControlNet(net; trainable::Bool=false)

Wrap the neural network `net` in a `ControlNet`.

# Keywords
- `trainable`: whether the wrapped network is trainable; defaults to `false`.
"""
ControlNet(net::N; trainable::Bool=false) where {N} = ControlNet{N}(net, trainable)
|
||
# Register `ControlNet` with Functors so that its fields are traversed as children.
@functor ControlNet

"""
    (c::ControlNet)(x)

Apply the wrapped network `c.net` to `x` and return its output.
"""
(c::ControlNet)(x) = c.net(x)
|
||
"""
    Flux.params(c::ControlNet)

Return the parameters of the wrapped network when `c.trainable` is `true`,
and `nothing` otherwise.
"""
function Flux.params(c::ControlNet)
    if c.trainable
        return Flux.params(c.net)
    end
    # NOTE(review): a frozen ControlNet yields `nothing` rather than an empty parameter
    # collection, so callers have to handle that case explicitly. It is also unclear
    # whether `Flux.params` on an enclosing model dispatches to this method - confirm.
    return nothing
end
|
||
"""
    CliMAgen.ControlledNoiseConditionalScoreNetwork

The struct containing the parameters and layers
of the Controlled Noise Conditional Score Network architecture,
with the option to include a mean-bypass layer.

The control network is constructed outside of this type. Its output is fed to
`Dense(embed_dim, ...)` layers and added to the time embedding, so it is expected to
have `embed_dim` features along its first dimension; this is not checked here.

# References
Unet: https://arxiv.org/abs/1505.04597
"""
struct ControlledNoiseConditionalScoreNetwork{C}
    "The layers of the network"
    layers::NamedTuple
    "A control network to condition the output of the U-net"
    control_net::C
    "A boolean indicating if a mean-bypass layer should be used"
    mean_bypass::Bool
    "A boolean indicating if the output of the mean-bypass layer should be scaled"
    scale_mean_bypass::Bool
    "A boolean indicating if the input is demeaned before being passed to the U-net"
    shift_input::Bool
    "A boolean indicating if the output of the Unet is demeaned"
    shift_output::Bool
    "A boolean indicating if a groupnorm should be used in the mean-bypass layer"
    gnorm::Bool
end
|
||
"""
    ControlledNoiseConditionalScoreNetwork(; control_net, kwargs...)

Build the layers of a controlled noise conditional score network and wrap them,
together with `control_net` and the behaviour flags, in a
`ControlledNoiseConditionalScoreNetwork`.

# Keywords
- `control_net`: network applied to the contextual input (required).
- `mean_bypass`: if `true`, the mean-bypass layers are created.
- `scale_mean_bypass`, `shift_input`, `shift_output`, `gnorm`: flags stored on the
  returned network. When `gnorm` and `mean_bypass` are both `true`, group-norm layers
  are added to the mean-bypass layers.
- `nspatial`: number of spatial dimensions; every convolution kernel is built with
  this many entries.
- `dropout_p`: passed as `p` to each residual block.
- `num_residual`: number of residual blocks.
- `noised_channels`: number of input channels and of output channels.
- `channels`: channel widths; entries 1 to 4 are used.
- `embed_dim`: size of the time embedding and input size of the control dense layers.
- `scale`: passed to `GaussianFourierProjection`.
- `periodic`: if `true`, circular convolutions are used for lifting and projection,
  and the flag is forwarded to the down/upsampling and residual blocks.
- `proj_kernelsize`, `outer_kernelsize`, `middle_kernelsize`, `inner_kernelsize`:
  kernel sizes of `tconv1`, `tconv2`, `tconv3` and `tconv4` respectively.

Requesting `scale_mean_bypass` or `gnorm` without `mean_bypass` only logs an error;
construction continues.
"""
function ControlledNoiseConditionalScoreNetwork(; control_net,
                                                mean_bypass=false,
                                                scale_mean_bypass=false,
                                                shift_input=false,
                                                shift_output=false,
                                                gnorm=false,
                                                nspatial=2,
                                                dropout_p=0.0f0,
                                                num_residual=8,
                                                noised_channels=1,
                                                channels=[32, 64, 128, 256],
                                                embed_dim=256,
                                                scale=30.0f0,
                                                periodic=false,
                                                proj_kernelsize=3,
                                                outer_kernelsize=3,
                                                middle_kernelsize=3,
                                                inner_kernelsize=3)
    if scale_mean_bypass && !mean_bypass
        @error("Attempting to scale the mean bypass term without adding in a mean bypass connection.")
    end
    if gnorm && !mean_bypass
        @error("Attempting to gnorm without adding in a mean bypass connection.")
    end
    inchannels = noised_channels
    outchannels = noised_channels

    # Kernel-size tuples with one entry per spatial dimension, so that the plain `Conv`
    # layers agree with `nspatial` like every other layer of the network does.
    # For `nspatial == 2` these are `(1, 1)`, `(3, 3)` and
    # `(proj_kernelsize, proj_kernelsize)`.
    pointwise_kernel = ntuple(_ -> 1, nspatial)
    lift_kernel = ntuple(_ -> 3, nspatial)
    proj_kernel = ntuple(_ -> proj_kernelsize, nspatial)

    # Mean processing as indicated by boolean mean_bypass
    if mean_bypass
        if gnorm
            mean_bypass_layers = (
                mean_skip_1 = Conv(pointwise_kernel, inchannels => embed_dim),
                mean_skip_2 = Conv(pointwise_kernel, embed_dim => embed_dim),
                mean_skip_3 = Conv(pointwise_kernel, embed_dim => outchannels),
                mean_gnorm_1 = GroupNorm(embed_dim, 32, swish),
                mean_gnorm_2 = GroupNorm(embed_dim, 32, swish),
                mean_dense_1 = Dense(embed_dim, embed_dim),
                mean_dense_2 = Dense(embed_dim, embed_dim),
            )
        else
            mean_bypass_layers = (
                mean_skip_1 = Conv(pointwise_kernel, inchannels => embed_dim),
                mean_skip_2 = Conv(pointwise_kernel, embed_dim => embed_dim),
                mean_skip_3 = Conv(pointwise_kernel, embed_dim => outchannels),
                mean_dense_1 = Dense(embed_dim, embed_dim),
                mean_dense_2 = Dense(embed_dim, embed_dim),
            )
        end
    else
        mean_bypass_layers = ()
    end

    # Lifting/Projection layers depend on periodicity of data
    if periodic
        conv1 = CircularConv(3, nspatial, inchannels => channels[1]; stride=1)
        tconv1 = CircularConv(proj_kernelsize, nspatial, channels[1] + channels[1] => outchannels; stride=1)
    else
        conv1 = Conv(lift_kernel, inchannels => channels[1], stride=1, pad=SamePad())
        tconv1 = Conv(proj_kernel, channels[1] + channels[1] => outchannels, stride=1, pad=SamePad())
    end

    layers = (gaussfourierproj=GaussianFourierProjection(embed_dim, scale),
              linear=Dense(embed_dim, embed_dim, swish),

              # Lifting
              conv1=conv1,
              dense1=Dense(embed_dim, channels[1]),
              control_dense1=Dense(embed_dim, channels[1]),
              gnorm1=GroupNorm(channels[1], 4, swish),

              # Encoding
              conv2=Downsampling(channels[1] => channels[2], nspatial, kernel_size=3, periodic=periodic),
              dense2=Dense(embed_dim, channels[2]),
              control_dense2=Dense(embed_dim, channels[2]),
              gnorm2=GroupNorm(channels[2], 32, swish),

              conv3=Downsampling(channels[2] => channels[3], nspatial, kernel_size=3, periodic=periodic),
              dense3=Dense(embed_dim, channels[3]),
              control_dense3=Dense(embed_dim, channels[3]),
              gnorm3=GroupNorm(channels[3], 32, swish),

              conv4=Downsampling(channels[3] => channels[4], nspatial, kernel_size=3, periodic=periodic),
              dense4=Dense(embed_dim, channels[4]),
              control_dense4=Dense(embed_dim, channels[4]),

              # Residual Blocks
              resnet_blocks =
              [ResnetBlockNCSN(channels[end], nspatial, embed_dim; p = dropout_p, periodic=periodic) for _ in 1:num_residual],

              # Decoding
              gnorm4=GroupNorm(channels[4], 32, swish),
              tconv4=Upsampling(channels[4] => channels[3], nspatial, kernel_size=inner_kernelsize, periodic=periodic),
              denset4=Dense(embed_dim, channels[3]),
              control_denset4=Dense(embed_dim, channels[3]),
              tgnorm4=GroupNorm(channels[3], 32, swish),

              tconv3=Upsampling(channels[3]+channels[3] => channels[2], nspatial, kernel_size=middle_kernelsize, periodic=periodic),
              denset3=Dense(embed_dim, channels[2]),
              control_denset3=Dense(embed_dim, channels[2]),
              tgnorm3=GroupNorm(channels[2], 32, swish),

              tconv2=Upsampling(channels[2]+channels[2] => channels[1], nspatial, kernel_size=outer_kernelsize, periodic=periodic),
              denset2=Dense(embed_dim, channels[1]),
              control_denset2=Dense(embed_dim, channels[1]),
              tgnorm2=GroupNorm(channels[1], 32, swish),

              # Projection
              tconv1=tconv1,
              mean_bypass_layers...
              )

    return ControlledNoiseConditionalScoreNetwork(layers, control_net, mean_bypass, scale_mean_bypass, shift_input, shift_output, gnorm)
end
|
||
# Register the network with Functors so that its fields are traversed as children.
@functor ControlledNoiseConditionalScoreNetwork

"""
    (net::ControlledNoiseConditionalScoreNetwork)(x, c, t)

Evaluate the controlled noise conditional score network on `(x, c, t)`, where `x` is
the tensor of noised input, `c` is the tensor of contextual input (passed to
`net.control_net`), and `t` is a tensor of times.
"""
function (net::ControlledNoiseConditionalScoreNetwork)(x, c, t)
    layers = net.layers

    # The last two dimensions of `x` are not spatial; the channel dimension directly
    # follows the spatial ones.
    nspatial = ndims(x) - 2
    spatial_dims = 1:nspatial
    channel_dim = nspatial + 1

    # Embeddings
    time_embed = layers.linear(layers.gaussfourierproj(t))
    control_embed = net.control_net(c)

    # Add the projected time and control embeddings to the feature map `h`.
    condition(h, time_proj, control_proj) =
        h .+ expand_dims(time_proj(time_embed) .+ control_proj(control_embed), nspatial)

    # Encoder; optionally remove the mean of the noised variables before input
    x_in = net.shift_input ? x .- mean(x, dims=spatial_dims) : x
    h1 = layers.gnorm1(condition(layers.conv1(x_in), layers.dense1, layers.control_dense1))
    h2 = layers.gnorm2(condition(layers.conv2(h1), layers.dense2, layers.control_dense2))
    h3 = layers.gnorm3(condition(layers.conv3(h2), layers.dense3, layers.control_dense3))
    h4 = condition(layers.conv4(h3), layers.dense4, layers.control_dense4)

    # Middle: the control embedding is simply summed with the time embedding here,
    # which can perhaps be done better.
    h = h4
    for block in layers.resnet_blocks
        h = block(h, time_embed .+ control_embed)
    end

    # Decoder, concatenating the encoder activations along the channel dimension
    h = layers.tconv4(layers.gnorm4(h))
    h = layers.tgnorm4(condition(h, layers.denset4, layers.control_denset4))
    h = layers.tconv3(cat(h, h3; dims=channel_dim))
    h = layers.tgnorm3(condition(h, layers.denset3, layers.control_denset3))
    h = layers.tconv2(cat(h, h2; dims=channel_dim))
    h = layers.tgnorm2(condition(h, layers.denset2, layers.control_denset2))
    h = layers.tconv1(cat(h, h1; dims=channel_dim))
    if net.shift_output
        h = h .- mean(h, dims=spatial_dims) # remove mean after output
    end

    # Without a mean bypass the U-net output is the result.
    if !net.mean_bypass
        return h
    end

    # Mean processing of noised variable channels.
    # NOTE(review): this branch is conditioned on the time embedding only; the control
    # embedding is not used here - confirm whether that is intended.
    hm = layers.mean_skip_1(mean(x, dims=spatial_dims))
    hm = hm .+ expand_dims(layers.mean_dense_1(time_embed), nspatial)
    if net.gnorm
        hm = layers.mean_gnorm_1(hm)
    end
    hm = layers.mean_skip_2(hm)
    hm = hm .+ expand_dims(layers.mean_dense_2(time_embed), nspatial)
    if net.gnorm
        hm = layers.mean_gnorm_2(hm)
    end
    hm = layers.mean_skip_3(hm)
    if net.scale_mean_bypass
        # Divide by the square root of the number of spatial grid points.
        scale = convert(eltype(x), sqrt(prod(size(x)[1:nspatial])))
        hm = hm ./ scale
    end

    # Add back in noised channel mean to noised channel spatial variations
    return h .+ hm
end
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
of the Controlled Noise Conditional Score Network