feat: add NeighborLoader #497

Open
wants to merge 27 commits into
base: master
Conversation

askorupka
Collaborator

This PR adds NeighborLoader functionality for big graphs.
This is similar to https://pytorch-geometric.readthedocs.io/en/latest/tutorial/neighbor_loader.html

PR with documentation will follow.

@askorupka
Collaborator Author

Hi @CarloLucibello, I just wanted to check whether it's OK to add Graphs to deps instead of weakdeps, or whether you see a better solution.
I need it in order to use the Graphs.neighbors function in graph subsampling.

@CarloLucibello
Owner

> Hi @CarloLucibello, I just wanted to check whether it's OK to add Graphs to deps instead of weakdeps, or whether you see a better solution.
> I need it in order to use the Graphs.neighbors function in graph subsampling.

That is fine; Graphs.jl is already a dependency through GNNGraphs.jl.

@askorupka
Collaborator Author

Now dependent on & blocked by #499


# Returns:
A vector of sampled neighbor node indices.
"""
Owner

Don't attach docstrings to internal methods.

Owner

You marked this as resolved, but the docstring is still here.
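
To illustrate the request above, here is a minimal sketch of an internal helper documented with a plain comment rather than a docstring. The name `_sample_neighbors`, its signature, and the `Dict`-based adjacency list are assumptions made for illustration; the PR itself relies on `Graphs.neighbors`, which is not used here so that the snippet runs on its own.

```julia
using Random

# Internal helper (hypothetical name): documented with a plain comment,
# so nothing is attached to the docstring system.
# Returns at most `k` distinct neighbors of `node` from the adjacency list `adj`.
function _sample_neighbors(
        rng::AbstractRNG,
        adj::Dict{Int, Vector{Int}},
        node::Int,
        k::Int,
    )
    neighs = get(adj, node, Int[])
    # Fewer than `k` neighbors available: return them all.
    length(neighs) <= k && return copy(neighs)
    # Otherwise take the first `k` entries of a random permutation.
    return shuffle(rng, neighs)[1:k]
end

adj = Dict(1 => [2, 3, 4], 2 => [1], 3 => [1], 4 => [1])
sampled = _sample_neighbors(MersenneTwister(0), adj, 1, 2)
```

Passing the random number generator explicitly keeps the helper easy to test with a fixed seed.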

@askorupka
Collaborator Author

@CarloLucibello ready for another look

# Usage
```julia
loader = NeighborLoader(graph; num_neighbors=[10, 5], input_nodes=[1, 2, 3], num_layers=2)
```
Owner

Here one could show that the loader can be iterated over, something like:

for g in loader
  # do something
end
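
For context, a `for` loop like the one above works on any type that implements Julia's iteration protocol. The following is only a rough sketch with a hypothetical `ToyLoader` that yields batches of node indices; it is not the `NeighborLoader` from this PR, which would presumably yield sampled subgraphs instead.

```julia
# Hypothetical stand-in for a mini-batch loader; not the PR's NeighborLoader.
struct ToyLoader
    input_nodes::Vector{Int}
    batch_size::Int
end

# Number of mini-batches, rounding up so the last partial batch is kept.
Base.length(l::ToyLoader) = cld(length(l.input_nodes), l.batch_size)

# Iteration protocol: `state` is the index of the first node of the next batch.
function Base.iterate(l::ToyLoader, state::Int = 1)
    state > length(l.input_nodes) && return nothing
    stop = min(state + l.batch_size - 1, length(l.input_nodes))
    return l.input_nodes[state:stop], stop + 1
end

loader = ToyLoader([1, 2, 3, 4, 5], 2)
batches = [batch for batch in loader]
```

Defining `Base.length` alongside `Base.iterate` lets comprehensions and `collect` work without extra trait declarations.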

@@ -0,0 +1,110 @@
"""
struct NeighborLoader
Owner

Suggested change
-     struct NeighborLoader
+     NeighborLoader(graph; num_neighbors, input_nodes, num_layers, [batch_size])

end

function NeighborLoader(graph::GNNGraph; num_neighbors::Vector{Int}, input_nodes::Vector{Int}, num_layers::Int, batch_size::Union{Int, Nothing}=nothing)
return NeighborLoader(graph, num_neighbors, input_nodes, num_layers, batch_size === nothing ? length(input_nodes) : batch_size, Dict{Int, Vector{Int}}())
Owner

These lines are too long; break them.


neighbors_cache::Dict{Int, Vector{Int}} # Cache neighbors to avoid recomputation
end

function NeighborLoader(graph::GNNGraph; num_neighbors::Vector{Int}, input_nodes::Vector{Int}, num_layers::Int, batch_size::Union{Int, Nothing}=nothing)
Owner

input_nodes could default to nothing. If it is not passed, the loader should consider all nodes in the graph.
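
One possible shape for that default, as a sketch only: `ToyGraph` and `ToyNeighborLoader` are hypothetical stand-ins so the snippet runs without GraphNeuralNetworks.jl, and falling back to `1:num_nodes` is an assumption about what "all nodes" should mean. The signature is also wrapped over several lines, in the spirit of the earlier comment about line length.

```julia
# Hypothetical stand-ins so the sketch runs without GraphNeuralNetworks.jl.
struct ToyGraph
    num_nodes::Int
end

struct ToyNeighborLoader
    graph::ToyGraph
    num_neighbors::Vector{Int}
    input_nodes::Vector{Int}
    num_layers::Int
    batch_size::Int
    neighbors_cache::Dict{Int, Vector{Int}}
end

function ToyNeighborLoader(
        graph::ToyGraph;
        num_neighbors::Vector{Int},
        input_nodes::Union{Vector{Int}, Nothing} = nothing,
        num_layers::Int,
        batch_size::Union{Int, Nothing} = nothing,
    )
    # Assumed behaviour: fall back to every node when `input_nodes` is omitted.
    nodes = input_nodes === nothing ? collect(1:graph.num_nodes) : input_nodes
    # Same fallback as in the PR: one batch containing all input nodes.
    bs = batch_size === nothing ? length(nodes) : batch_size
    return ToyNeighborLoader(
        graph, num_neighbors, nodes, num_layers, bs,
        Dict{Int, Vector{Int}}(),
    )
end

loader = ToyNeighborLoader(ToyGraph(4); num_neighbors = [10, 5], num_layers = 2)
```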

@@ -9,6 +9,7 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c"
GNNlib = "a6a84749-d869-43f8-aacc-be26a1996e48"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
Owner

You should add the corresponding compat bound.

@@ -10,6 +10,7 @@ using NNlib: scatter, gather
using ChainRulesCore
using Reexport
using MLUtils: zeros_like
using Graphs
Owner

Suggested change
- using Graphs
+ using Graphs: Graphs

Let's not import all symbols.
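
As a small illustration of the difference, the sketch below uses the standard library `LinearAlgebra` in place of Graphs.jl so that it runs without extra packages. With the `using Module: Module` form only the module name is brought into scope, and every function has to be called with its qualified name.

```julia
# Bring only the module name into scope, not its exported symbols.
using LinearAlgebra: LinearAlgebra

# Functions are reached through the module name.
len = LinearAlgebra.norm([3.0, 4.0])

# A bare call such as `norm([3.0, 4.0])` would raise an UndefVarError here,
# because `norm` itself was never imported into this namespace.
```

By analogy, the sampler would then call `Graphs.neighbors(...)` explicitly, which also makes it clear at the call site where the function comes from.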
