-
Notifications
You must be signed in to change notification settings - Fork 45
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: add NeighborLoader #497
base: master
Are you sure you want to change the base?
Conversation
63c6711
to
b8f23d4
Compare
Hi @CarloLucibello I just wanted to check if it's ok to add |
It is fine, |
Now dependent on & blocked by #499 |
4d1e9be
to
5e7544c
Compare
|
||
# Returns: | ||
A vector of sampled neighbor node indices. | ||
""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
don't attach docstrings to internal methods.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you marked as resolve but the docstring is still here
Co-authored-by: Carlo Lucibello <[email protected]>
@CarloLucibello ready for another look |
# Usage | ||
```julia | ||
loader = NeighborLoader(graph; num_neighbors=[10, 5], input_nodes=[1, 2, 3], num_layers=2) | ||
``` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
here one could show that the loader can be iterated over, something like
for g in loader
# do something
end
Co-authored-by: Carlo Lucibello <[email protected]>
@@ -0,0 +1,110 @@ | |||
""" | |||
struct NeighborLoader |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
struct NeighborLoader | |
NeighborLoader(graph; num_neighbors, input_nodes, num_layers, [batch_size]) |
end | ||
|
||
function NeighborLoader(graph::GNNGraph; num_neighbors::Vector{Int}, input_nodes::Vector{Int}, num_layers::Int, batch_size::Union{Int, Nothing}=nothing) | ||
return NeighborLoader(graph, num_neighbors, input_nodes, num_layers, batch_size === nothing ? length(input_nodes) : batch_size, Dict{Int, Vector{Int}}()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lines too long, break them
|
||
# Returns: | ||
A vector of sampled neighbor node indices. | ||
""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you marked as resolve but the docstring is still here
neighbors_cache::Dict{Int, Vector{Int}} # Cache neighbors to avoid recomputation | ||
end | ||
|
||
function NeighborLoader(graph::GNNGraph; num_neighbors::Vector{Int}, input_nodes::Vector{Int}, num_layers::Int, batch_size::Union{Int, Nothing}=nothing) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
input_node
could have default value = nothing
. If not passed it should consider all nodes in the graph.
@@ -9,6 +9,7 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" | |||
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" | |||
GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c" | |||
GNNlib = "a6a84749-d869-43f8-aacc-be26a1996e48" | |||
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should add the corresponding compat bound
@@ -10,6 +10,7 @@ using NNlib: scatter, gather | |||
using ChainRulesCore | |||
using Reexport | |||
using MLUtils: zeros_like | |||
using Graphs |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
using Graphs | |
using Graphs: Graphs |
let's not import all symbols
This PR adds NeighborLoader functionality for big graphs.
This is similar to https://pytorch-geometric.readthedocs.io/en/latest/tutorial/neighbor_loader.html
PR with documentation will follow.