feat: add NeighborLoader #497

Open
wants to merge 27 commits into
base: master
Changes from 23 commits
Commits
349d99f
feat: init neighbor loader
askorupka Sep 11, 2024
07db4d0
feat: init neighbor loader
askorupka Sep 11, 2024
25945c7
feat: refine neighborloader
askorupka Sep 11, 2024
506d4c7
fix: refine neighborloader
askorupka Sep 11, 2024
991bf61
fix: refine neighborloader
askorupka Sep 11, 2024
c25bc1e
fix: refine neighborloader
askorupka Sep 11, 2024
fde10bb
fix: refine neighborloader
askorupka Sep 11, 2024
3656691
fix: refine neighborloader
askorupka Sep 11, 2024
9997fab
chore: add some comments
askorupka Sep 12, 2024
acf209c
chore: add TODO comments
askorupka Sep 17, 2024
0c4a653
feat: add tests, refine code
askorupka Sep 28, 2024
2035c5e
fix: add samplers.jl after rebase
askorupka Sep 28, 2024
ebebce9
chore: add docstrings
askorupka Sep 29, 2024
abf31cd
chore: Graphs to deps
askorupka Sep 29, 2024
bcdfa5e
chore: move using Graphs to main file
askorupka Sep 29, 2024
970d297
chore: readd Graphs to extras
askorupka Sep 29, 2024
b4c1ad7
chore: delete src/samplers.jl created by mistake
askorupka Sep 29, 2024
5e7544c
fix: add sampling.jl to docs
askorupka Oct 12, 2024
c9d412b
fix: add sampling.jl to docs
askorupka Oct 12, 2024
2d7bd0b
fix: add sampling.jl to docs
askorupka Oct 12, 2024
65aa564
fix: deduplicate function
askorupka Oct 12, 2024
61c5e39
fix: fix broken tests
askorupka Oct 24, 2024
aec5574
chore: remove printlns
askorupka Oct 24, 2024
e675086
Update GraphNeuralNetworks/src/samplers.jl
askorupka Oct 27, 2024
62f5d87
fix: remove docstrings where not needed
askorupka Oct 27, 2024
3ed22bf
chore: add ref to the paper
askorupka Oct 27, 2024
e4dc977
Update GraphNeuralNetworks/src/samplers.jl
askorupka Oct 29, 2024
4 changes: 3 additions & 1 deletion GNNGraphs/src/sampling.jl
@@ -177,6 +177,8 @@ function Graphs.induced_subgraph(graph::GNNGraph, nodes::Vector{Int})

node_map = Dict(node => i for (i, node) in enumerate(nodes))

edge_list = [collect(t) for t in zip(edge_index(graph)[1],edge_index(graph)[2])]

# Collect edges to add
source = Int[]
target = Int[]
@@ -188,7 +190,7 @@ function Graphs.induced_subgraph(graph::GNNGraph, nodes::Vector{Int})
push!(target, node_map[node])
push!(source, node_map[neighbor])

eindex = findfirst(x -> x == [neighbor, node], edge_index(graph))
eindex = findfirst(x -> x == [neighbor, node], edge_list)
push!(eindices, eindex)
end
end
1 change: 1 addition & 0 deletions GraphNeuralNetworks/Project.toml
@@ -9,6 +9,7 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c"
GNNlib = "a6a84749-d869-43f8-aacc-be26a1996e48"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
Owner
We should also add the corresponding `[compat]` bound for `Graphs`.

LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1 change: 1 addition & 0 deletions GraphNeuralNetworks/docs/make.jl
@@ -46,6 +46,7 @@ makedocs(;
"Message Passing" => "api/messagepassing.md",
"Heterogeneous Graphs" => "api/heterograph.md",
"Temporal Graphs" => "api/temporalgraph.md",
"Samplers" => "api/samplers.md",
"Utils" => "api/utils.md",
],
"Developer Notes" => "dev.md",
14 changes: 14 additions & 0 deletions GraphNeuralNetworks/docs/src/api/samplers.md
@@ -0,0 +1,14 @@
```@meta
CurrentModule = GraphNeuralNetworks
```

# Samplers


## Docs

```@autodocs
Modules = [GraphNeuralNetworks]
Pages = ["samplers.jl"]
Private = false
```
4 changes: 4 additions & 0 deletions GraphNeuralNetworks/src/GraphNeuralNetworks.jl
@@ -10,6 +10,7 @@ using NNlib: scatter, gather
using ChainRulesCore
using Reexport
using MLUtils: zeros_like
using Graphs
Owner
Suggested change
using Graphs
using Graphs: Graphs

let's not import all symbols
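
As a rough illustration of the suggestion above: `using SomeModule: SomeModule` binds only the module name, so every call has to be qualified and none of the exported symbols enter the namespace. The sketch below uses the `Random` standard library as a stand-in for `Graphs` (an assumption made only so the snippet runs without extra packages); the module name `SelectiveImportDemo` is hypothetical.

```julia
module SelectiveImportDemo

# Bind only the module name; exported functions stay out of this namespace.
using Random: Random

# Qualified access works as usual.
shuffled = Random.shuffle([1, 2, 3])

# `shuffle` itself was not brought into scope by the import above.
unqualified_visible = isdefined(@__MODULE__, :shuffle)

end # module
```

With this style, calls in `samplers.jl` such as `Graphs.neighbors(...)` and `Graphs.induced_subgraph(...)` keep working because they are already qualified.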


using GNNGraphs: COO_T, ADJMAT_T, SPARSE_T,
check_num_nodes, check_num_edges,
@@ -66,4 +67,7 @@ export GlobalPool,

include("deprecations.jl")

include("samplers.jl")
export NeighborLoader

end
155 changes: 155 additions & 0 deletions GraphNeuralNetworks/src/samplers.jl
@@ -0,0 +1,155 @@
"""
struct NeighborLoader
Owner
Suggested change
struct NeighborLoader
NeighborLoader(graph; num_neighbors, input_nodes, num_layers, [batch_size])


A data structure for sampling neighbors from a graph for training Graph Neural Networks (GNNs).
It supports multi-layer sampling of neighbors for a batch of input nodes, useful for mini-batch training.

# Fields
- `graph::GNNGraph`: The input graph containing nodes and edges, along with node features (from GraphNeuralNetworks.jl).
askorupka marked this conversation as resolved.
- `num_neighbors::Vector{Int}`: A vector specifying the number of neighbors to sample per node at each GNN layer.
- `input_nodes::Vector{Int}`: A vector containing the starting nodes for neighbor sampling.
- `num_layers::Int`: The number of layers for neighborhood expansion (how far to sample neighbors).
- `batch_size::Union{Int, Nothing}`: The size of the batch. If not specified, it defaults to the number of `input_nodes`.
- `neighbors_cache::Dict{Int, Vector{Int}}`: A cache to store sampled neighbors for each node, preventing redundant sampling.
askorupka marked this conversation as resolved.

askorupka marked this conversation as resolved.
# Usage
```julia
loader = NeighborLoader(graph; num_neighbors=[10, 5], input_nodes=[1, 2, 3], num_layers=2)
```
Owner
here one could show that the loader can be iterated over, something like

for g in loader
  # do something
end
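
To make the iteration idea concrete without depending on the package, here is a minimal sketch of the same iterator protocol on a toy type. `ToyLoader` is hypothetical and only batches node ids; it does not sample neighbors or build subgraphs the way `NeighborLoader` in this PR does.

```julia
# Hypothetical stand-in for the PR's loader: it only batches node ids.
struct ToyLoader
    input_nodes::Vector{Int}
    batch_size::Int
end

# Iterator protocol: return (item, next_state), or `nothing` when exhausted.
function Base.iterate(loader::ToyLoader, state::Int = 1)
    state > length(loader.input_nodes) && return nothing
    stop = min(state + loader.batch_size - 1, length(loader.input_nodes))
    return loader.input_nodes[state:stop], stop + 1
end

# Not needed for `for` loops, but `collect` relies on `length` by default.
Base.length(loader::ToyLoader) = cld(length(loader.input_nodes), loader.batch_size)
Base.eltype(::Type{ToyLoader}) = Vector{Int}

loader = ToyLoader([1, 2, 3, 4, 5], 2)
batches = Vector{Int}[]
for batch in loader
    push!(batches, batch)  # the last batch is shorter than `batch_size`
end
```

The `for` loop only needs `Base.iterate`; defining `Base.length` and `Base.eltype` as well is a design choice that lets generic functions such as `collect` work on the loader.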

"""
struct NeighborLoader
graph::GNNGraph # The input GNNGraph (graph + features from GraphNeuralNetworks.jl)
num_neighbors::Vector{Int} # Number of neighbors to sample per node, for each layer
input_nodes::Vector{Int} # Set of input nodes (starting nodes for sampling)
num_layers::Int # Number of layers for neighborhood expansion
batch_size::Union{Int, Nothing} # Optional batch size, defaults to the length of input_nodes if not given
neighbors_cache::Dict{Int, Vector{Int}} # Cache neighbors to avoid recomputation
end

### `NeighborLoader` constructor
"""
NeighborLoader(graph::GNNGraph; num_neighbors::Vector{Int}, input_nodes::Vector{Int}, num_layers::Int, batch_size::Union{Int, Nothing}=nothing)

Creates a `NeighborLoader` to sample neighbors from the provided `graph` for the training.
The loader supports batching and multi-layer neighbor sampling.

# Arguments:
- `graph::GNNGraph`: The input graph with node features.
- `num_neighbors::Vector{Int}`: Number of neighbors to sample per node, per layer.
- `input_nodes::Vector{Int}`: Set of starting nodes for sampling.
- `num_layers::Int`: Number of layers to expand neighborhoods for sampling.
- `batch_size::Union{Int, Nothing}`: Optional batch size. If `nothing`, it defaults to the length of `input_nodes`.

# Returns:
A `NeighborLoader` object.
"""
askorupka marked this conversation as resolved.
function NeighborLoader(graph::GNNGraph; num_neighbors::Vector{Int}, input_nodes::Vector{Int}, num_layers::Int, batch_size::Union{Int, Nothing}=nothing)
Owner
`input_nodes` could default to `nothing`. If it is not passed, the loader should consider all nodes in the graph.
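
One possible shape for that default, sketched on a toy type so it runs on its own. `ToyNeighborLoader` and its `num_nodes` field are hypothetical stand-ins for the real struct and graph; only the `nothing` fallbacks are the point here.

```julia
# Hypothetical miniature of the loader; `num_nodes` stands in for the graph.
struct ToyNeighborLoader
    num_nodes::Int
    input_nodes::Vector{Int}
    batch_size::Int
end

function ToyNeighborLoader(
    num_nodes::Int;
    input_nodes::Union{Vector{Int}, Nothing} = nothing,
    batch_size::Union{Int, Nothing} = nothing
)
    # Fall back to every node when no seed nodes are given.
    nodes = input_nodes === nothing ? collect(1:num_nodes) : input_nodes
    # Fall back to a single batch covering all seed nodes.
    bs = batch_size === nothing ? length(nodes) : batch_size
    return ToyNeighborLoader(num_nodes, nodes, bs)
end

all_nodes_loader = ToyNeighborLoader(4)
seeded_loader = ToyNeighborLoader(4; input_nodes = [2, 3], batch_size = 1)
```

Resolving both defaults inside the constructor keeps the stored fields concretely typed, so the iteration code never has to check for `nothing`.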

return NeighborLoader(graph, num_neighbors, input_nodes, num_layers, batch_size === nothing ? length(input_nodes) : batch_size, Dict{Int, Vector{Int}}())
Owner
lines too long, break them

end

"""
get_neighbors(loader::NeighborLoader, node::Int)

Returns the neighbors of a given `node` in the graph from the `NeighborLoader`.
It first checks if the neighbors are cached; if not, it retrieves the neighbors from the graph and caches them for future use.

# Arguments:
- `loader::NeighborLoader`: The `NeighborLoader` instance.
- `node::Int`: The node whose neighbors you want to sample.

# Returns:
A vector of neighbor node indices.
"""
askorupka marked this conversation as resolved.
# Function to get cached neighbors or compute them
function get_neighbors(loader::NeighborLoader, node::Int)
if haskey(loader.neighbors_cache, node)
return loader.neighbors_cache[node]
else
neighbors = Graphs.neighbors(loader.graph, node, dir = :in) # Get neighbors from graph
loader.neighbors_cache[node] = neighbors
return neighbors
end
end

"""
sample_nbrs(loader::NeighborLoader, node::Int, layer::Int)

Samples a specified number of neighbors for the given `node` at a particular `layer` of the GNN.
The number of neighbors sampled is defined in `loader.num_neighbors`.

# Arguments:
- `loader::NeighborLoader`: The `NeighborLoader` instance.
- `node::Int`: The node to sample neighbors for.
- `layer::Int`: The current GNN layer (used to determine how many neighbors to sample).

# Returns:
A vector of sampled neighbor node indices.
"""
Owner
don't attach docstrings to internal methods.

Owner
You marked this as resolved, but the docstring is still here.

# Function to sample neighbors for a given node at a specific layer
function sample_nbrs(loader::NeighborLoader, node::Int, layer::Int)
neighbors = get_neighbors(loader, node)
if isempty(neighbors)
return Int[]
else
num_samples = min(loader.num_neighbors[layer], length(neighbors)) # Limit to required samples for this layer
return rand(neighbors, num_samples) # Randomly sample neighbors
end
end

"""
Base.iterate(loader::NeighborLoader, state::Int=1)

Implements the iterator protocol for `NeighborLoader`, allowing mini-batch processing for neighbor sampling in GNNs.
Each call to `iterate` returns a mini-batch subgraph with sampled neighbors for a batch of input nodes,
expanding their neighborhoods for a specified number of layers.

# Arguments:
- `loader::NeighborLoader`: The `NeighborLoader` instance to sample neighbors from.
- `state::Int`: The current position in the input nodes for batching. Defaults to 1.

# Returns:
A tuple `(mini_batch_gnn, next_state)` where:
- `mini_batch_gnn::GNNGraph`: The subgraph induced by the sampled nodes and their neighbors for the current mini-batch.
- `next_state::Int`: The next state (index) for iterating through the input nodes. Once the input nodes are exhausted, `iterate` itself returns `nothing` instead of a tuple.
"""
askorupka marked this conversation as resolved.
# Iterator protocol for NeighborLoader with lazy batch loading
function Base.iterate(loader::NeighborLoader, state=1)
if state > length(loader.input_nodes)
return nothing # End of iteration: every input node has already been placed in a batch
end

# Determine the size of the current batch
batch_size = min(loader.batch_size, length(loader.input_nodes) - state + 1) # The last batch may be smaller if there are not enough nodes left to fill it
batch_nodes = loader.input_nodes[state:state + batch_size - 1] # Each mini-batch uses a different set of input nodes

# Set for tracking the subgraph nodes
subgraph_nodes = Set(batch_nodes)

for node in batch_nodes
# Initialize current layer of nodes (starting with the node itself)
sampled_neighbors = Set([node])

# For each GNN layer, sample the neighborhood
for layer in 1:loader.num_layers
new_neighbors = Set{Int}()
for n in sampled_neighbors
neighbors = sample_nbrs(loader, n, layer) # Sample neighbors of the node for this layer
new_neighbors = union(new_neighbors, neighbors) # Avoid duplicates in the neighbor set
end
sampled_neighbors = new_neighbors
subgraph_nodes = union(subgraph_nodes, sampled_neighbors) # Expand the subgraph with the new neighbors
end
end

# Collect subgraph nodes and their features
subgraph_node_list = collect(subgraph_nodes)

if isempty(subgraph_node_list)
return GNNGraph(), state + batch_size
end

mini_batch_gnn = Graphs.induced_subgraph(loader.graph, subgraph_node_list) # Create a subgraph of the nodes

# Continue iteration for the next batch
return mini_batch_gnn, state + batch_size
end
1 change: 1 addition & 0 deletions GraphNeuralNetworks/test/runtests.jl
@@ -30,6 +30,7 @@ tests = [
"layers/temporalconv",
"layers/pool",
"examples/node_classification_cora",
"samplers"
]

!CUDA.functional() && @warn("CUDA unavailable, not testing GPU support")
125 changes: 125 additions & 0 deletions GraphNeuralNetworks/test/samplers.jl
@@ -0,0 +1,125 @@
# Helper function to create a simple graph with node features using GNNGraph
function create_test_graph()
source = [1, 2, 3, 4] # Define source nodes of edges
target = [2, 3, 4, 5] # Define target nodes of edges
node_features = rand(Float32, 5, 5) # Create random node features (5 features for 5 nodes)

return GNNGraph(source, target, ndata = node_features) # Create a GNNGraph with edges and features
end

# Tests for NeighborLoader structure and its functionalities
@testset "NeighborLoader tests" begin

# 1. Basic functionality: Check neighbor sampling and subgraph creation
@testset "Basic functionality" begin
g = create_test_graph()

# Define NeighborLoader with 2 neighbors per layer, 2 layers, batch size 2
loader = NeighborLoader(g; num_neighbors=[2, 2], input_nodes=[1, 2], num_layers=2, batch_size=2)

mini_batch_gnn, next_state = iterate(loader)

# Test if the mini-batch graph is not empty
@test !isempty(mini_batch_gnn.graph)

num_sampled_nodes = mini_batch_gnn.num_nodes
println("Number of nodes in mini-batch: ", num_sampled_nodes)

@test num_sampled_nodes == 2

# Test if there are edges in the subgraph
@test mini_batch_gnn.num_edges > 0
end

# 2. Edge case: Single node with no neighbors
@testset "Single node with no neighbors" begin
g = SimpleDiGraph(1) # A graph with a single node and no edges
node_features = rand(Float32, 5, 1)
graph = GNNGraph(g, ndata = node_features)

loader = NeighborLoader(graph; num_neighbors=[2], input_nodes=[1], num_layers=1)

mini_batch_gnn, next_state = iterate(loader)

# Test if the mini-batch graph contains only one node
@test size(mini_batch_gnn.x, 2) == 1
end

# 3. Edge case: A node with no outgoing edges (isolated node)
@testset "Node with no outgoing edges" begin
g = SimpleDiGraph(2) # Graph with 2 nodes, no edges
node_features = rand(Float32, 5, 2)
graph = GNNGraph(g, ndata = node_features)

loader = NeighborLoader(graph; num_neighbors=[1], input_nodes=[1, 2], num_layers=1)

mini_batch_gnn, next_state = iterate(loader)

# Test if the mini-batch graph contains the input nodes only (as no neighbors can be sampled)
@test size(mini_batch_gnn.x, 2) == 2 # Only two isolated nodes
end

# 4. Edge case: A fully connected graph
@testset "Fully connected graph" begin
g = SimpleDiGraph(3)
add_edge!(g, 1, 2)
add_edge!(g, 2, 3)
add_edge!(g, 3, 1)
node_features = rand(Float32, 5, 3)
graph = GNNGraph(g, ndata = node_features)

loader = NeighborLoader(graph; num_neighbors=[2, 2], input_nodes=[1], num_layers=2)

mini_batch_gnn, next_state = iterate(loader)

# Test if all nodes are included: the graph is a directed 3-cycle, so two hops along incoming edges from node 1 reach nodes 3 and 2
@test size(mini_batch_gnn.x, 2) == 3 # All nodes should be included
end

# 5. Edge case: More layers than the number of neighbors
@testset "More layers than available neighbors" begin
g = SimpleDiGraph(3)
add_edge!(g, 1, 2)
add_edge!(g, 2, 3)
node_features = rand(Float32, 5, 3)
graph = GNNGraph(g, ndata = node_features)

# Test with 3 layers but only enough connections for 2 layers
loader = NeighborLoader(graph; num_neighbors=[1, 1, 1], input_nodes=[1], num_layers=3)

mini_batch_gnn, next_state = iterate(loader)

# Node 1 has no incoming edges, so the mini-batch graph contains only the input node
@test size(mini_batch_gnn.x, 2) == 1
end

# 6. Edge case: Large batch size greater than the number of input nodes
@testset "Large batch size" begin
g = create_test_graph()

# Define NeighborLoader with a larger batch size than input nodes
loader = NeighborLoader(g; num_neighbors=[2], input_nodes=[1, 2], num_layers=1, batch_size=10)

mini_batch_gnn, next_state = iterate(loader)

# Test if the mini-batch graph is not empty
@test !isempty(mini_batch_gnn.graph)

# Test if the correct number of nodes are sampled
@test size(mini_batch_gnn.x, 2) == length(unique([1, 2])) # Nodes [1, 2] are expected
end

# 7. Edge case: No neighbors sampled (num_neighbors = [0]) and 1 layer
@testset "No neighbors sampled" begin
g = create_test_graph()

# Define NeighborLoader with 0 neighbors per layer, 1 layer, batch size 2
loader = NeighborLoader(g; num_neighbors=[0], input_nodes=[1, 2], num_layers=1, batch_size=2)

mini_batch_gnn, next_state = iterate(loader)

# Test if the mini-batch graph contains only the input nodes
@test size(mini_batch_gnn.x, 2) == 2 # No neighbors should be sampled, only nodes 1 and 2 should be in the graph
end

end