-
Notifications
You must be signed in to change notification settings - Fork 45
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: add NeighborLoader #497
base: master
Are you sure you want to change the base?
Changes from 23 commits
349d99f
07db4d0
25945c7
506d4c7
991bf61
c25bc1e
fde10bb
3656691
9997fab
acf209c
0c4a653
2035c5e
ebebce9
abf31cd
bcdfa5e
970d297
b4c1ad7
5e7544c
c9d412b
2d7bd0b
65aa564
61c5e39
aec5574
e675086
62f5d87
3ed22bf
e4dc977
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
```@meta | ||
CurrentModule = GraphNeuralNetworks | ||
``` | ||
|
||
# Samplers | ||
|
||
|
||
## Docs | ||
|
||
```@autodocs | ||
Modules = [GraphNeuralNetworks] | ||
Pages = ["samplers.jl"] | ||
Private = false | ||
``` |
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -10,6 +10,7 @@ using NNlib: scatter, gather | |||||
using ChainRulesCore | ||||||
using Reexport | ||||||
using MLUtils: zeros_like | ||||||
using Graphs | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
let's not import all symbols |
||||||
|
||||||
using GNNGraphs: COO_T, ADJMAT_T, SPARSE_T, | ||||||
check_num_nodes, check_num_edges, | ||||||
|
@@ -66,4 +67,7 @@ export GlobalPool, | |||||
|
||||||
include("deprecations.jl") | ||||||
|
||||||
include("samplers.jl") | ||||||
export NeighborLoader | ||||||
|
||||||
end |
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,155 @@ | ||||||
""" | ||||||
struct NeighborLoader | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
||||||
A data structure for sampling neighbors from a graph for training Graph Neural Networks (GNNs). | ||||||
It supports multi-layer sampling of neighbors for a batch of input nodes, useful for mini-batch training. | ||||||
|
||||||
# Fields | ||||||
- `graph::GNNGraph`: The input graph containing nodes and edges, along with node features (from GraphNeuralNetworks.jl). | ||||||
askorupka marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
- `num_neighbors::Vector{Int}`: A vector specifying the number of neighbors to sample per node at each GNN layer. | ||||||
- `input_nodes::Vector{Int}`: A vector containing the starting nodes for neighbor sampling. | ||||||
- `num_layers::Int`: The number of layers for neighborhood expansion (how far to sample neighbors). | ||||||
- `batch_size::Union{Int, Nothing}`: The size of the batch. If not specified, it defaults to the number of `input_nodes`. | ||||||
- `neighbors_cache::Dict{Int, Vector{Int}}`: A cache to store sampled neighbors for each node, preventing redundant sampling. | ||||||
askorupka marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
||||||
askorupka marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
# Usage | ||||||
```julia | ||||||
loader = NeighborLoader(graph; num_neighbors=[10, 5], input_nodes=[1, 2, 3], num_layers=2) | ||||||
``` | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. here one could show that the loader can be iterated over, something like for g in loader
# do something
end |
||||||
""" | ||||||
struct NeighborLoader | ||||||
graph::GNNGraph # The input GNNGraph (graph + features from GraphNeuralNetworks.jl) | ||||||
num_neighbors::Vector{Int} # Number of neighbors to sample per node, for each layer | ||||||
input_nodes::Vector{Int} # Set of input nodes (starting nodes for sampling) | ||||||
num_layers::Int # Number of layers for neighborhood expansion | ||||||
batch_size::Union{Int, Nothing} # Optional batch size, defaults to the length of input_nodes if not given | ||||||
neighbors_cache::Dict{Int, Vector{Int}} # Cache neighbors to avoid recomputation | ||||||
end | ||||||
|
||||||
### `NeighborLoader` constructor | ||||||
""" | ||||||
NeighborLoader(graph::GNNGraph; num_neighbors::Vector{Int}, input_nodes::Vector{Int}, num_layers::Int, batch_size::Union{Int, Nothing}=nothing) | ||||||
|
||||||
Creates a `NeighborLoader` to sample neighbors from the provided `graph` for the training. | ||||||
The loader supports batching and multi-layer neighbor sampling. | ||||||
|
||||||
# Arguments: | ||||||
- `graph::GNNGraph`: The input graph with node features. | ||||||
- `num_neighbors::Vector{Int}`: Number of neighbors to sample per node, per layer. | ||||||
- `input_nodes::Vector{Int}`: Set of starting nodes for sampling. | ||||||
- `num_layers::Int`: Number of layers to expand neighborhoods for sampling. | ||||||
- `batch_size::Union{Int, Nothing}`: Optional batch size. If `nothing`, it defaults to the length of `input_nodes`. | ||||||
|
||||||
# Returns: | ||||||
A `NeighborLoader` object. | ||||||
""" | ||||||
askorupka marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
function NeighborLoader(graph::GNNGraph; num_neighbors::Vector{Int}, input_nodes::Vector{Int}, num_layers::Int, batch_size::Union{Int, Nothing}=nothing) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||||||
return NeighborLoader(graph, num_neighbors, input_nodes, num_layers, batch_size === nothing ? length(input_nodes) : batch_size, Dict{Int, Vector{Int}}()) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. lines too long, break them |
||||||
end | ||||||
|
||||||
""" | ||||||
get_neighbors(loader::NeighborLoader, node::Int) | ||||||
|
||||||
Returns the neighbors of a given `node` in the graph from the `NeighborLoader`. | ||||||
It first checks if the neighbors are cached; if not, it retrieves the neighbors from the graph and caches them for future use. | ||||||
|
||||||
# Arguments: | ||||||
- `loader::NeighborLoader`: The `NeighborLoader` instance. | ||||||
- `node::Int`: The node whose neighbors you want to sample. | ||||||
|
||||||
# Returns: | ||||||
A vector of neighbor node indices. | ||||||
""" | ||||||
askorupka marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
# Function to get cached neighbors or compute them | ||||||
function get_neighbors(loader::NeighborLoader, node::Int) | ||||||
if haskey(loader.neighbors_cache, node) | ||||||
return loader.neighbors_cache[node] | ||||||
else | ||||||
neighbors = Graphs.neighbors(loader.graph, node, dir = :in) # Get neighbors from graph | ||||||
loader.neighbors_cache[node] = neighbors | ||||||
return neighbors | ||||||
end | ||||||
end | ||||||
|
||||||
""" | ||||||
sample_nbrs(loader::NeighborLoader, node::Int, layer::Int) | ||||||
|
||||||
Samples a specified number of neighbors for the given `node` at a particular `layer` of the GNN. | ||||||
The number of neighbors sampled is defined in `loader.num_neighbors`. | ||||||
|
||||||
# Arguments: | ||||||
- `loader::NeighborLoader`: The `NeighborLoader` instance. | ||||||
- `node::Int`: The node to sample neighbors for. | ||||||
- `layer::Int`: The current GNN layer (used to determine how many neighbors to sample). | ||||||
|
||||||
# Returns: | ||||||
A vector of sampled neighbor node indices. | ||||||
""" | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. don't attach docstrings to internal methods. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you marked as resolve but the docstring is still here |
||||||
# Function to sample neighbors for a given node at a specific layer | ||||||
function sample_nbrs(loader::NeighborLoader, node::Int, layer::Int) | ||||||
neighbors = get_neighbors(loader, node) | ||||||
if isempty(neighbors) | ||||||
return Int[] | ||||||
else | ||||||
num_samples = min(loader.num_neighbors[layer], length(neighbors)) # Limit to required samples for this layer | ||||||
return rand(neighbors, num_samples) # Randomly sample neighbors | ||||||
end | ||||||
end | ||||||
|
||||||
""" | ||||||
Base.iterate(loader::NeighborLoader, state::Int=1) | ||||||
|
||||||
Implements the iterator protocol for `NeighborLoader`, allowing mini-batch processing for neighbor sampling in GNNs. | ||||||
Each call to `iterate` returns a mini-batch subgraph with sampled neighbors for a batch of input nodes, | ||||||
expanding their neighborhoods for a specified number of layers. | ||||||
|
||||||
# Arguments: | ||||||
- `loader::NeighborLoader`: The `NeighborLoader` instance to sample neighbors from. | ||||||
- `state::Int`: The current position in the input nodes for batching. Defaults to 1. | ||||||
|
||||||
# Returns: | ||||||
A tuple `(mini_batch_gnn, next_state)` where: | ||||||
- `mini_batch_gnn::GNNGraph`: The subgraph induced by the sampled nodes and their neighbors for the current mini-batch. | ||||||
- `next_state::Int`: The next state (index) for iterating through the input nodes. If the input nodes are exhausted, returns `nothing`. | ||||||
""" | ||||||
askorupka marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
# Iterator protocol for NeighborLoader with lazy batch loading | ||||||
function Base.iterate(loader::NeighborLoader, state=1) | ||||||
if state > length(loader.input_nodes) | ||||||
return nothing # End of iteration if batches are exhausted (state larger than amount of input nodes or current batch no >= batch number) | ||||||
end | ||||||
|
||||||
# Determine the size of the current batch | ||||||
batch_size = min(loader.batch_size, length(loader.input_nodes) - state + 1) # Conditional in case there is not enough nodes to fill the last batch | ||||||
batch_nodes = loader.input_nodes[state:state + batch_size - 1] # Each mini-batch uses different set of input nodes | ||||||
|
||||||
# Set for tracking the subgraph nodes | ||||||
subgraph_nodes = Set(batch_nodes) | ||||||
|
||||||
for node in batch_nodes | ||||||
# Initialize current layer of nodes (starting with the node itself) | ||||||
sampled_neighbors = Set([node]) | ||||||
|
||||||
# For each GNN layer, sample the neighborhood | ||||||
for layer in 1:loader.num_layers | ||||||
new_neighbors = Set{Int}() | ||||||
for n in sampled_neighbors | ||||||
neighbors = sample_nbrs(loader, n, layer) # Sample neighbors of the node for this layer | ||||||
new_neighbors = union(new_neighbors, neighbors) # Avoid duplicates in the neighbor set | ||||||
end | ||||||
sampled_neighbors = new_neighbors | ||||||
subgraph_nodes = union(subgraph_nodes, sampled_neighbors) # Expand the subgraph with the new neighbors | ||||||
end | ||||||
end | ||||||
|
||||||
# Collect subgraph nodes and their features | ||||||
subgraph_node_list = collect(subgraph_nodes) | ||||||
|
||||||
if isempty(subgraph_node_list) | ||||||
return GNNGraph(), state + batch_size | ||||||
end | ||||||
|
||||||
mini_batch_gnn = Graphs.induced_subgraph(loader.graph, subgraph_node_list) # Create a subgraph of the nodes | ||||||
|
||||||
# Continue iteration for the next batch | ||||||
return mini_batch_gnn, state + batch_size | ||||||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
# Helper function to create a simple graph with node features using GNNGraph | ||
function create_test_graph() | ||
source = [1, 2, 3, 4] # Define source nodes of edges | ||
target = [2, 3, 4, 5] # Define target nodes of edges | ||
node_features = rand(Float32, 5, 5) # Create random node features (5 features for 5 nodes) | ||
|
||
return GNNGraph(source, target, ndata = node_features) # Create a GNNGraph with edges and features | ||
end | ||
|
||
# Tests for NeighborLoader structure and its functionalities | ||
@testset "NeighborLoader tests" begin | ||
|
||
# 1. Basic functionality: Check neighbor sampling and subgraph creation | ||
@testset "Basic functionality" begin | ||
g = create_test_graph() | ||
|
||
# Define NeighborLoader with 2 neighbors per layer, 2 layers, batch size 2 | ||
loader = NeighborLoader(g; num_neighbors=[2, 2], input_nodes=[1, 2], num_layers=2, batch_size=2) | ||
|
||
mini_batch_gnn, next_state = iterate(loader) | ||
|
||
# Test if the mini-batch graph is not empty | ||
@test !isempty(mini_batch_gnn.graph) | ||
|
||
num_sampled_nodes = mini_batch_gnn.num_nodes | ||
println("Number of nodes in mini-batch: ", num_sampled_nodes) | ||
|
||
@test num_sampled_nodes == 2 | ||
|
||
# Test if there are edges in the subgraph | ||
@test mini_batch_gnn.num_edges > 0 | ||
end | ||
|
||
# 2. Edge case: Single node with no neighbors | ||
@testset "Single node with no neighbors" begin | ||
g = SimpleDiGraph(1) # A graph with a single node and no edges | ||
node_features = rand(Float32, 5, 1) | ||
graph = GNNGraph(g, ndata = node_features) | ||
|
||
loader = NeighborLoader(graph; num_neighbors=[2], input_nodes=[1], num_layers=1) | ||
|
||
mini_batch_gnn, next_state = iterate(loader) | ||
|
||
# Test if the mini-batch graph contains only one node | ||
@test size(mini_batch_gnn.x, 2) == 1 | ||
end | ||
|
||
# 3. Edge case: A node with no outgoing edges (isolated node) | ||
@testset "Node with no outgoing edges" begin | ||
g = SimpleDiGraph(2) # Graph with 2 nodes, no edges | ||
node_features = rand(Float32, 5, 2) | ||
graph = GNNGraph(g, ndata = node_features) | ||
|
||
loader = NeighborLoader(graph; num_neighbors=[1], input_nodes=[1, 2], num_layers=1) | ||
|
||
mini_batch_gnn, next_state = iterate(loader) | ||
|
||
# Test if the mini-batch graph contains the input nodes only (as no neighbors can be sampled) | ||
@test size(mini_batch_gnn.x, 2) == 2 # Only two isolated nodes | ||
end | ||
|
||
# 4. Edge case: A fully connected graph | ||
@testset "Fully connected graph" begin | ||
g = SimpleDiGraph(3) | ||
add_edge!(g, 1, 2) | ||
add_edge!(g, 2, 3) | ||
add_edge!(g, 3, 1) | ||
node_features = rand(Float32, 5, 3) | ||
graph = GNNGraph(g, ndata = node_features) | ||
|
||
loader = NeighborLoader(graph; num_neighbors=[2, 2], input_nodes=[1], num_layers=2) | ||
|
||
mini_batch_gnn, next_state = iterate(loader) | ||
|
||
# Test if all nodes are included in the mini-batch since it's fully connected | ||
@test size(mini_batch_gnn.x, 2) == 3 # All nodes should be included | ||
end | ||
|
||
# 5. Edge case: More layers than the number of neighbors | ||
@testset "More layers than available neighbors" begin | ||
g = SimpleDiGraph(3) | ||
add_edge!(g, 1, 2) | ||
add_edge!(g, 2, 3) | ||
node_features = rand(Float32, 5, 3) | ||
graph = GNNGraph(g, ndata = node_features) | ||
|
||
# Test with 3 layers but only enough connections for 2 layers | ||
loader = NeighborLoader(graph; num_neighbors=[1, 1, 1], input_nodes=[1], num_layers=3) | ||
|
||
mini_batch_gnn, next_state = iterate(loader) | ||
|
||
# Test if the mini-batch graph contains all available nodes | ||
@test size(mini_batch_gnn.x, 2) == 1 | ||
end | ||
|
||
# 6. Edge case: Large batch size greater than the number of input nodes | ||
@testset "Large batch size" begin | ||
g = create_test_graph() | ||
|
||
# Define NeighborLoader with a larger batch size than input nodes | ||
loader = NeighborLoader(g; num_neighbors=[2], input_nodes=[1, 2], num_layers=1, batch_size=10) | ||
|
||
mini_batch_gnn, next_state = iterate(loader) | ||
|
||
# Test if the mini-batch graph is not empty | ||
@test !isempty(mini_batch_gnn.graph) | ||
|
||
# Test if the correct number of nodes are sampled | ||
@test size(mini_batch_gnn.x, 2) == length(unique([1, 2])) # Nodes [1, 2] are expected | ||
end | ||
|
||
# 7. Edge case: No neighbors sampled (num_neighbors = [0]) and 1 layer | ||
@testset "No neighbors sampled" begin | ||
g = create_test_graph() | ||
|
||
# Define NeighborLoader with 0 neighbors per layer, 1 layer, batch size 2 | ||
loader = NeighborLoader(g; num_neighbors=[0], input_nodes=[1, 2], num_layers=1, batch_size=2) | ||
|
||
mini_batch_gnn, next_state = iterate(loader) | ||
|
||
# Test if the mini-batch graph contains only the input nodes | ||
@test size(mini_batch_gnn.x, 2) == 2 # No neighbors should be sampled, only nodes 1 and 2 should be in the graph | ||
end | ||
|
||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should add the corresponding compat bound