## Describe the bug
In `rl4co/models/nn/graph/gcn.py`, the `GCNEncoder` class processes batched node embeddings without adjusting the `edge_index` for each graph in the batch. This leads to incorrect message passing in `GCNConv`, where nodes from different graphs may be incorrectly connected or ignored, resulting in unintended behavior of the model.
**Issue in Code:**
```python
# In rl4co/models/nn/graph/gcn.py, inside the GCNEncoder class
def forward(self, td):
    # Transfer to embedding space
    init_h = self.init_embedding(td)
    bs, num_nodes, emb_dim = init_h.shape
    # (bs * num_nodes, emb_dim)
    update_node_feature = init_h.reshape(-1, emb_dim)  # Flatten batch
    # shape = (2, num_edges)
    edge_index = self.edge_idx_fn(td, num_nodes)  # Edge index for a single graph
    for layer in self.gcn_layers[:-1]:
        update_node_feature = layer(update_node_feature, edge_index)
    # ...
```
- **Problem:** `edge_index` is generated for a single graph and is not adjusted for batching.
- **Consequence:** When node features from multiple graphs are concatenated, `GCNConv` processes the flattened node tensor with indices that reference only the first graph, so only the first graph's nodes are updated properly.
## To Reproduce
**Minimal Example to Reproduce the Bug:**
```python
import torch
from torch_geometric.nn import GCNConv
from functools import lru_cache

# Function to generate full graph edge index (as in the original code)
@lru_cache(5)
def get_full_graph_edge_index(num_node: int, self_loop=False) -> torch.Tensor:
    adj_matrix = torch.ones(num_node, num_node)
    if not self_loop:
        adj_matrix.fill_diagonal_(0)
    edge_index = torch.permute(torch.nonzero(adj_matrix), (1, 0))
    return edge_index

# Parameters
num_nodes_per_graph = 5
num_node_features = 3
batch_size = 2  # Number of graphs in the batch

# Create node features for two graphs
x1 = torch.randn(num_nodes_per_graph, num_node_features)
x2 = torch.randn(num_nodes_per_graph, num_node_features)

# Concatenate node features to simulate batching
x = torch.cat([x1, x2], dim=0)  # Shape: [batch_size * num_nodes_per_graph, num_node_features]

# Generate edge_index using the original function
edge_index_single = get_full_graph_edge_index(num_nodes_per_graph, self_loop=False)

# Use the same edge_index for both graphs without adjusting
edge_index = edge_index_single  # Edge indices from graph 1, node indices 0 to 4

# Initialize GCNConv
conv = GCNConv(in_channels=num_node_features, out_channels=2, add_self_loops=False)

# Apply GCNConv without adjusting edge_index
out = conv(x, edge_index)

# Print the output
print("Output without adjusting edge_index:")
print(out)

# Now adjust edge_index for each graph to reflect batching
edge_indices = []
for i in range(batch_size):
    offset = i * num_nodes_per_graph
    adjusted_edge_index = edge_index_single + offset
    edge_indices.append(adjusted_edge_index)
edge_index_adjusted = torch.cat(edge_indices, dim=1)

# Apply GCNConv with adjusted edge_index
out_adjusted = conv(x, edge_index_adjusted)

# Compare outputs for nodes belonging to each graph
print("\nDifference in outputs for nodes belonging to the first graph:")
print(out[:num_nodes_per_graph] - out_adjusted[:num_nodes_per_graph])
print("\nDifference in outputs for nodes belonging to the second graph:")
print(out[num_nodes_per_graph:] - out_adjusted[num_nodes_per_graph:])
```
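A side observation (not part of the original report): the mismatch is silent, because every index in the unadjusted `edge_index` is still a valid row of `x`, so `GCNConv` runs without raising even though nodes 5 to 9 never receive messages. Continuing from the script above:

```python
# Sanity check: the unadjusted edge_index never triggers an index error,
# since its largest entry (4) is a valid row of x; nodes 5-9 are simply
# never referenced and therefore never aggregated over.
print(edge_index.max().item(), x.size(0))  # prints: 4 10
assert edge_index.max().item() < x.size(0)
```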
**Explanation:**
- **Edge index generation:** We use the `get_full_graph_edge_index` function from the original code to generate the edge indices for a fully connected graph without self-loops.
- **Without adjusted `edge_index`:** The `edge_index` only references nodes 0 to 4. Nodes 5 to 9 (from the second graph) are not properly connected and thus not processed correctly by `GCNConv`.
- **With adjusted `edge_index`:** By adding an offset to `edge_index` for each graph in the batch, nodes are correctly connected within their own graphs, and the outputs for nodes in the second graph are now properly updated based on their graph structure.
- **Observation:** The differences in outputs for nodes belonging to the second graph indicate that, without adjusting `edge_index`, those nodes were not processed correctly.
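As a cross-check (a sketch using PyTorch Geometric's own batching utility, not code from rl4co): `Batch.from_data_list` applies exactly this per-graph offset when collating graphs, so it can be used to verify the manually adjusted `edge_index`. It assumes the variables `x1`, `x2`, `edge_index_single`, and `edge_index_adjusted` from the reproduction script above.

```python
from torch_geometric.data import Data, Batch

# PyG increments node indices per graph when collating, which should
# reproduce the manually offset edge_index built in the script above.
data_list = [
    Data(x=x1, edge_index=edge_index_single),
    Data(x=x2, edge_index=edge_index_single),
]
batch = Batch.from_data_list(data_list)

print(torch.equal(batch.edge_index, edge_index_adjusted))  # expected: True
```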
## Additional context
The issue arises because in the `GCNEncoder`, the `edge_index` is generated once for a single graph and reused for all graphs in the batch without adjustment. This causes incorrect node indexing when processing batched graphs, leading to unintended behavior.
## Reason and Possible fixes
**Reason:**
- The `edge_index` is not adjusted for batched graphs, causing nodes from different graphs to be incorrectly connected or not processed correctly.
- In batched graphs, the node indices in `edge_index` need to be offset for each graph to reflect their positions in the concatenated node feature tensor.
**Possible Fix:** Adjust the `edge_index` for each graph in the batch by adding an offset based on the cumulative number of nodes. Here's a modification to the `GCNEncoder` class:
```python
def forward(self, td):
    init_h = self.init_embedding(td)
    bs, num_nodes, emb_dim = init_h.shape
    update_node_feature = init_h.reshape(-1, emb_dim)

    # Original edge_index for a single graph
    edge_index_single = self.edge_idx_fn(td, num_nodes)

    # Adjust edge_index for batching
    edge_indices = []
    for i in range(bs):
        offset = i * num_nodes
        adjusted_edge_index = edge_index_single + offset
        edge_indices.append(adjusted_edge_index)
    edge_index = torch.cat(edge_indices, dim=1).to(td.device)

    # Proceed with GCN layers using the adjusted edge_index
    for layer in self.gcn_layers[:-1]:
        update_node_feature = layer(update_node_feature, edge_index)
        update_node_feature = F.relu(update_node_feature)
        update_node_feature = F.dropout(
            update_node_feature, training=self.training, p=self.dropout
        )

    # Last layer without activation and dropout
    update_node_feature = self.gcn_layers[-1](update_node_feature, edge_index)

    # De-batch the graph
    update_node_feature = update_node_feature.view(bs, num_nodes, emb_dim)

    # Residual connection
    if self.residual:
        update_node_feature = update_node_feature + init_h

    return update_node_feature, init_h
```
**Explanation:**
- **Adjustment of `edge_index`:** For each graph in the batch, we add an offset to `edge_index` so that node indices correctly map to their positions in the concatenated `update_node_feature`.
- **Benefit:** Nodes are only connected within their own graph, preventing unintended cross-graph connections, so `GCNConv` can now process each graph correctly within the batch.
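As a side note, if the Python loop over the batch is a concern for large batch sizes, the same offsetting can be done in a vectorized way. This is only a sketch of an equivalent construction (the helper name `batchify_edge_index` is made up for illustration):

```python
import torch

def batchify_edge_index(edge_index_single: torch.Tensor, bs: int, num_nodes: int) -> torch.Tensor:
    """Repeat a single-graph edge_index `bs` times, offsetting node indices per graph.

    Equivalent to the loop in the fix above, without a Python-level loop.
    """
    # Per-graph offsets: 0, num_nodes, 2 * num_nodes, ... with shape (bs, 1, 1)
    offsets = torch.arange(bs, device=edge_index_single.device).view(-1, 1, 1) * num_nodes
    # Broadcast to (bs, 2, num_edges), then flatten to (2, bs * num_edges)
    return (edge_index_single.unsqueeze(0) + offsets).permute(1, 0, 2).reshape(2, -1)
```

The result has the same column ordering as `torch.cat([edge_index_single + i * num_nodes for i in range(bs)], dim=1)`.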
## Checklist
- [x] I have checked that there is no similar issue in the repo (required)
- [x] I have provided a minimal working example to reproduce the bug (required)