Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds in basic control net infrastructure and some tests #174

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

bischtob
Copy link
Collaborator

@bischtob bischtob commented Feb 23, 2024

We would like a basic control net setup that allows for custom conditioning of a pretrained network. Here, we build towards that end.

  • ControlNet struct & tests
  • Net ControlledNCSN that can use any ControlNet
  • Tests

@bischtob bischtob self-assigned this Feb 23, 2024
@bischtob bischtob added the enhancement New feature or request label Feb 23, 2024
@sandreza
Copy link

What's left to do here? Should I run some GPU tests to see if everything carries through?

CliMAgen.ControlledNoiseConditionalScoreNetwork

The struct containing the parameters and layers
of the Noise Conditional Score Network architecture,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

of the Controlled Noise Conditional Score Network

end

# Mean processing of noised variable channels
if net.mean_bypass
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should the embedded control/condition also be used by the mean bypass layer?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(we do use the context in the mean bypass in our other version, though it's only the mean of the context that is used)

"The layers of the network"
layers::NamedTuple
"A control network to condition the output of the U-net"
control_net::N
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like the output of the control net needs to be a vector of length embed_dim (same as for the time embedding). Is this correct? Since the control net is made exterior to the NCSN, where embed_dim is defined, we should maybe just add some documentation stating this, or maybe an @Assert statement somewhere, if we can?

Copy link
Member

@kmdeck kmdeck left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it looks great, I left a couple comments that might need to be addressed first before a merge.

thank you!

@kmdeck
Copy link
Member

kmdeck commented Feb 24, 2024

oh, one other question. the control net itself is either pretrained or else can be trained with the rest of the network, it looks like. is that correct?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants