Message Passing in GNNs

Categories: news, code, analysis

Author: Bulent Soykan

Published: February 16, 2025

Message Passing in Graph Neural Networks (GNNs) as Matrix Multiplication

Graph Neural Networks (GNNs) use message passing to aggregate information from neighboring nodes. This process can be understood as a simple matrix multiplication involving the adjacency matrix and the node feature matrix.


1. Node Features as a Matrix

Each node in a graph has a feature vector. If we have $N$ nodes and each node has $F$ features, we represent the node features as a matrix $X$:

$$
X = \begin{bmatrix} x_1^T \\ x_2^T \\ x_3^T \\ \vdots \\ x_N^T \end{bmatrix}
$$

where $x_i$ is the feature vector of node $i$, and $X$ has shape $(N \times F)$.

Example (3 Nodes, 2 Features each)

X = [
    [1, 2],
    [3, 4],
    [5, 6]
]
  • Node 1 has features [1,2]
  • Node 2 has features [3,4]
  • Node 3 has features [5,6]
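In NumPy, this feature matrix is just an $(N \times F)$ array; a minimal sketch mirroring the example above:

```python
import numpy as np

# Node feature matrix X with N = 3 nodes and F = 2 features per node
X = np.array([[1, 2],   # node 1
              [3, 4],   # node 2
              [5, 6]])  # node 3
print(X.shape)  # (3, 2), i.e. (N, F)
```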

2. Adjacency Matrix as a Propagation Rule

The adjacency matrix $A$ determines which nodes pass messages to each other.

$$
A = \begin{bmatrix} 0 & 1 & 1 \\ 1 & 0 & 1 \\ 1 & 1 & 0 \end{bmatrix}
$$

  • Node 1 connects to Nodes 2 and 3.
  • Node 2 connects to Nodes 1 and 3.
  • Node 3 connects to Nodes 1 and 2.
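In practice, $A$ is usually built from an edge list rather than written out by hand. The sketch below assumes the graph is undirected and given as 0-indexed node pairs (the variable names are illustrative); it reproduces the matrix above:

```python
import numpy as np

# Undirected edge list for the 3-node example (0-indexed)
edges = [(0, 1), (0, 2), (1, 2)]

N = 3
A = np.zeros((N, N), dtype=int)
for i, j in edges:
    A[i, j] = 1
    A[j, i] = 1  # undirected graph: add both directions

print(A)
# [[0 1 1]
#  [1 0 1]
#  [1 1 0]]
```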

3. Message Passing as Matrix Multiplication

Message passing means each node aggregates features from its neighbors. Mathematically, this is:

$$
AX
$$

where row $i$ of $AX$ is the sum of the feature vectors of node $i$'s neighbors.

Computing $AX$

Multiply $A$ (shape $N \times N$) by $X$ (shape $N \times F$):

$$
AX = \begin{bmatrix} 0 & 1 & 1 \\ 1 & 0 & 1 \\ 1 & 1 & 0 \end{bmatrix} \begin{bmatrix} 1 & 2 \\ 3 & 4 \\ 5 & 6 \end{bmatrix}
$$

Computing row-wise:

  • Node 1’s new feature:
    $0 \cdot [1,2] + 1 \cdot [3,4] + 1 \cdot [5,6] = [8, 10]$
  • Node 2’s new feature:
    $1 \cdot [1,2] + 0 \cdot [3,4] + 1 \cdot [5,6] = [6, 8]$
  • Node 3’s new feature:
    $1 \cdot [1,2] + 1 \cdot [3,4] + 0 \cdot [5,6] = [4, 6]$

So, the new feature matrix is:

$$
AX = \begin{bmatrix} 8 & 10 \\ 6 & 8 \\ 4 & 6 \end{bmatrix}
$$
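The row-wise sums above are exactly what an explicit loop over each node's neighbors would compute. A small sketch to confirm that the loop and the matrix product $AX$ agree on this example:

```python
import numpy as np

A = np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]])
X = np.array([[1, 2], [3, 4], [5, 6]])

# Explicit message passing: each node sums its neighbors' feature vectors
aggregated = np.zeros_like(X)
for i in range(A.shape[0]):
    for j in range(A.shape[1]):
        if A[i, j] == 1:          # j is a neighbor of i
            aggregated[i] += X[j]

print(np.array_equal(aggregated, A @ X))  # True
```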


4. Normalization (Avoiding Scale Explosion)

Raw adjacency multiplication can lead to large feature values, so we normalize it using degree normalization:

$$
D^{-1} A X
$$

where $D$ is the degree matrix (a diagonal matrix with node degrees on the diagonal):

$$
D = \begin{bmatrix} 2 & 0 & 0 \\ 0 & 2 & 0 \\ 0 & 0 & 2 \end{bmatrix}
$$

So,

$$
D^{-1} = \begin{bmatrix} \frac{1}{2} & 0 & 0 \\ 0 & \frac{1}{2} & 0 \\ 0 & 0 & \frac{1}{2} \end{bmatrix}
$$

Multiplying:

$$
D^{-1}AX = \begin{bmatrix} \frac{1}{2} & 0 & 0 \\ 0 & \frac{1}{2} & 0 \\ 0 & 0 & \frac{1}{2} \end{bmatrix} \begin{bmatrix} 8 & 10 \\ 6 & 8 \\ 4 & 6 \end{bmatrix} = \begin{bmatrix} 4 & 5 \\ 3 & 4 \\ 2 & 3 \end{bmatrix}
$$
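Because $D$ is diagonal, $D^{-1}AX$ amounts to dividing each row of $AX$ by the corresponding node's degree, so there is no need to form or invert $D$ explicitly. A minimal sketch of this shortcut:

```python
import numpy as np

A = np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]])
X = np.array([[1, 2], [3, 4], [5, 6]])

deg = A.sum(axis=1)                      # node degrees: [2, 2, 2]
normalized_AX = (A @ X) / deg[:, None]   # divide each row of AX by its degree
print(normalized_AX)
# [[4. 5.]
#  [3. 4.]
#  [2. 3.]]
```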


Summary

  • Message passing spreads node features across the graph.
  • It is done using adjacency matrix multiplication: $AX$.
  • Normalization ($D^{-1} A X$) prevents feature explosion.

Python Implementation

```python
import numpy as np
import torch
import torch.nn as nn

# NumPy implementation
X = np.array([[1, 2], [3, 4], [5, 6]])           # node features (N x F)
A = np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]])  # adjacency matrix (N x N)
AX = A @ X                                       # message passing: aggregate neighbor features
D = np.diag(A.sum(axis=1))                       # degree matrix
D_inv = np.linalg.inv(D)
normalized_AX = D_inv @ AX                       # degree-normalized aggregation
print("AX (NumPy):\n", AX)
print("D^{-1}AX (NumPy):\n", normalized_AX)

# PyTorch implementation with learnable weights
class GNNLayer(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.W = nn.Linear(in_features, out_features, bias=False)

    def forward(self, X, A):
        D = torch.diag(torch.sum(A, dim=1))      # degree matrix
        D_inv = torch.inverse(D)
        AX = torch.matmul(A, X)                  # aggregate neighbor features
        normalized_AX = torch.matmul(D_inv, AX)  # D^{-1} A X
        return self.W(normalized_AX)             # learnable linear transformation

X_torch = torch.tensor(X, dtype=torch.float32)
A_torch = torch.tensor(A, dtype=torch.float32)

gnn_layer = GNNLayer(in_features=2, out_features=2)
output = gnn_layer(X_torch, A_torch)
print("GNN Output with Learnable Weights:", output)
```

The PyTorch implementation adds a learnable weight matrix, using nn.Linear to transform the normalized, aggregated features.