Replies: 2 comments
-
^ Bumping. Thank you!
-
FullyConnectedTensorProduct in e3nn-jax:

import jax
import e3nn_jax as e3nn

x1 = e3nn.normal("0e + 1o", jax.random.PRNGKey(0), ())
x2 = e3nn.normal("0e + 1o", jax.random.PRNGKey(1), ())  # different key, so x2 != x1
y = e3nn.tensor_product(x1, x2)  # weightless: contains every allowed path
lin = e3nn.flax.Linear("3x0e")  # the learnable weights live in this layer
w = lin.init(jax.random.PRNGKey(0), y)
lin.apply(w, y)

Gate Functionality in e3nn-jax:

import jax.numpy as jnp
import e3nn_jax as e3nn

input = e3nn.IrrepsArray(
    "0e + 0o + 1o + 1o + 0e + 0o",
    jnp.array([1.0, 1.0, 10.0, 0.0, 0.0, 10.0, 0.0, 0.0, 1.0, 1.0]),
)
# input = [se, so, v1, v2, ge, go]
output = e3nn.gate(
    input,
    even_act=jnp.tanh,
    odd_act=jnp.tanh,
    even_gate_act=jnp.tanh,
    odd_gate_act=jnp.tanh,
)
# output = even_act(se), odd_act(so), even_gate_act(ge) * v1, odd_gate_act(go) * v2
-
I know there have been some previous discussions about mapping e3nn's FullyConnectedTensorProduct to an equivalent form in e3nn-jax. I'm seeking further clarification on this topic, as well as on the use of Gate.
FullyConnectedTensorProduct in e3nn vs. e3nn-jax:
In e3nn, I employ FullyConnectedTensorProduct in the following manner:
I then use tp for several operations, such as tp.weight_views(), tp.instructions and tp.weight_numel, and I also call it directly as tp(data1, data2). Could you please explain in detail how to do each of these in e3nn-jax? Specifically, how can I replicate FullyConnectedTensorProduct and its associated methods?
Gate Functionality in e3nn vs. e3nn-jax:
Regarding the Gate mechanism, I understand that e3nn-jax offers finer control over its behavior than e3nn does. In e3nn, my implementation looks something like this:
I'm particularly interested in how to translate this Gate configuration to e3nn-jax while taking advantage of the additional control and customization it provides.
Thanks!