Message Passing in Graph Neural Networks (GNNs) as Matrix Multiplication
Graph Neural Networks (GNNs) use message passing to aggregate information from neighboring nodes. This process can be understood as a simple matrix multiplication involving the adjacency matrix and the node feature matrix.
1. Node Features as a Matrix
Each node in a graph has a feature vector. If we have \( N \) nodes and each node has \( F \) features, we represent the node features as a matrix \( X \):

\[ X = \begin{bmatrix} x_1^T \\ x_2^T \\ x_3^T \\ \vdots \\ x_N^T \end{bmatrix} \]

where \( x_i \) is the feature vector of node \( i \), and \( X \) has shape \( N \times F \).
Example (3 nodes, 2 features each):

```python
X = [
    [1, 2],
    [3, 4],
    [5, 6],
]
```
- Node 1 has features [1, 2]
- Node 2 has features [3, 4]
- Node 3 has features [5, 6]
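In code, \( X \) is just a 2-D array. A minimal NumPy sketch, using the values from the example above:

```python
import numpy as np

# Feature matrix X: one row per node, one column per feature
X = np.array([
    [1, 2],  # node 1
    [3, 4],  # node 2
    [5, 6],  # node 3
])

print(X.shape)  # (3, 2): N = 3 nodes, F = 2 features
```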
2. Adjacency Matrix as a Propagation Rule
The adjacency matrix ( A ) determines which nodes pass messages to each other.
\[ A = \begin{bmatrix} 0 & 1 & 1 \\ 1 & 0 & 1 \\ 1 & 1 & 0 \end{bmatrix} \]
- Node 1 connects to Nodes 2 and 3.
- Node 2 connects to Nodes 1 and 3.
- Node 3 connects to Nodes 1 and 2.
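In practice, \( A \) is often built from an edge list. A minimal sketch for the graph above (the edges list is an illustrative input, not from the original):

```python
import numpy as np

# Undirected edges of the example graph (0-indexed nodes)
edges = [(0, 1), (0, 2), (1, 2)]

N = 3
A = np.zeros((N, N), dtype=int)
for i, j in edges:
    A[i, j] = 1
    A[j, i] = 1  # undirected graph: A is symmetric

print(A)
# [[0 1 1]
#  [1 0 1]
#  [1 1 0]]
```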
3. Message Passing as Matrix Multiplication
Message passing means each node aggregates features from its neighbors. Mathematically, this is:

\[ AX \]

where row \( i \) of \( AX \) is the sum of the feature vectors of node \( i \)'s neighbors.
Computing \( AX \)

Multiply \( A \) (shape \( N \times N \)) by \( X \) (shape \( N \times F \)):

\[ AX = \begin{bmatrix} 0 & 1 & 1 \\ 1 & 0 & 1 \\ 1 & 1 & 0 \end{bmatrix} \begin{bmatrix} 1 & 2 \\ 3 & 4 \\ 5 & 6 \end{bmatrix} \]
Computing row-wise:

- Node 1's new feature: \( 0 \cdot [1,2] + 1 \cdot [3,4] + 1 \cdot [5,6] = [8, 10] \)
- Node 2's new feature: \( 1 \cdot [1,2] + 0 \cdot [3,4] + 1 \cdot [5,6] = [6, 8] \)
- Node 3's new feature: \( 1 \cdot [1,2] + 1 \cdot [3,4] + 0 \cdot [5,6] = [4, 6] \)
So, the new feature matrix is:
\[ AX = \begin{bmatrix} 8 & 10 \\ 6 & 8 \\ 4 & 6 \end{bmatrix} \]
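A quick NumPy check of the hand computation, with the values from the example above:

```python
import numpy as np

A = np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]])
X = np.array([[1, 2], [3, 4], [5, 6]])

# Row i of A @ X is the sum of node i's neighbors' features
print(A @ X)
# [[ 8 10]
#  [ 6  8]
#  [ 4  6]]
```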
4. Normalization (Avoiding Scale Explosion)
Summing raw neighbor features makes values grow with node degree, so we normalize by degree:

\[ D^{-1} A X \]

where \( D \) is the degree matrix (a diagonal matrix whose entries are the node degrees):

\[ D = \begin{bmatrix} 2 & 0 & 0 \\ 0 & 2 & 0 \\ 0 & 0 & 2 \end{bmatrix} \]
So,
\[ D^{-1} = \begin{bmatrix} \frac{1}{2} & 0 & 0 \\ 0 & \frac{1}{2} & 0 \\ 0 & 0 & \frac{1}{2} \end{bmatrix} \]
Multiplying:

\[ D^{-1}AX = \begin{bmatrix} \frac{1}{2} & 0 & 0 \\ 0 & \frac{1}{2} & 0 \\ 0 & 0 & \frac{1}{2} \end{bmatrix} \begin{bmatrix} 8 & 10 \\ 6 & 8 \\ 4 & 6 \end{bmatrix} = \begin{bmatrix} 4 & 5 \\ 3 & 4 \\ 2 & 3 \end{bmatrix} \]

Each node's feature is now the mean of its neighbors' features.
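The same normalization in NumPy; a minimal sketch reusing the example's A and X:

```python
import numpy as np

A = np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]])
X = np.array([[1, 2], [3, 4], [5, 6]])

deg = A.sum(axis=1)         # node degrees: [2 2 2]
D_inv = np.diag(1.0 / deg)  # D^{-1}, valid when every node has at least one neighbor

print(D_inv @ A @ X)
# [[4. 5.]
#  [3. 4.]
#  [2. 3.]]
```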
Summary
- Message passing spreads node features across the graph.
- It is implemented as multiplication by the adjacency matrix: \( AX \).
- Degree normalization (\( D^{-1} A X \)) keeps feature values from growing with node degree.
Python Implementation
```python
import numpy as np
import torch
import torch.nn as nn

# NumPy Implementation
X = np.array([[1, 2], [3, 4], [5, 6]])           # node features (N = 3, F = 2)
A = np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]])  # adjacency matrix

AX = A @ X                    # message passing: sum neighbor features
D = np.diag(A.sum(axis=1))    # degree matrix
D_inv = np.linalg.inv(D)
normalized_AX = D_inv @ AX    # mean of neighbor features

print("AX (NumPy):\n", AX)
print("D^{-1}AX (NumPy):\n", normalized_AX)

# PyTorch Implementation with Learnable Weights
class GNNLayer(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        # Learnable weight matrix applied after aggregation
        self.W = nn.Linear(in_features, out_features, bias=False)

    def forward(self, X, A):
        D = torch.diag(torch.sum(A, dim=1))      # degree matrix
        D_inv = torch.inverse(D)
        AX = torch.matmul(A, X)                  # message passing
        normalized_AX = torch.matmul(D_inv, AX)  # degree normalization
        return self.W(normalized_AX)

X_torch = torch.tensor(X, dtype=torch.float32)
A_torch = torch.tensor(A, dtype=torch.float32)

gnn_layer = GNNLayer(in_features=2, out_features=2)
output = gnn_layer(X_torch, A_torch)
print("GNN Output with Learnable Weights:\n", output)
```
The PyTorch implementation adds a learnable weight matrix, using nn.Linear to transform the aggregated node features.
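Layers like this are typically stacked so that information propagates over multiple hops. A minimal two-layer sketch, assuming GNNLayer, X_torch, and A_torch are defined as above (the hidden size of 4 and the ReLU nonlinearity are illustrative choices, not from the original):

```python
# Two rounds of message passing: each node now sees its 2-hop neighborhood
layer1 = GNNLayer(in_features=2, out_features=4)
layer2 = GNNLayer(in_features=4, out_features=2)

h = torch.relu(layer1(X_torch, A_torch))  # hidden representation after hop 1
out = layer2(h, A_torch)                  # output after hop 2
print(out)
```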