Vectorizing Graph Algorithms

Originally posted 2020-09-13

Tagged: software engineering, computer science, python, popular ⭐️

Obligatory disclaimer: all opinions are mine and not of my employer


Graph neural networks (GNNs) have become increasingly popular over the last few years. Deep neural networks have reinvented multiple fields (image recognition, language translation, audio recognition/generation), so the appeal of GNNs for graph data is self-evident.

GNNs present a computational challenge to most deep learning libraries, which are typically optimized for dense matrix multiplications and additions. While graph algorithms can typically be expressed as \(O(N^3)\) dense matrix multiplications on adjacency matrices, such representations are inefficient as most real-world graphs are sparse, only requiring \(O(N)\) operations in theory. On the other hand, sparse representations like adjacency lists have ragged shapes, which are not typical fare for deep learning libraries.

In this essay, I’ll demonstrate how to use NumPy APIs to implement computationally efficient GNNs for sparse graphs, using PageRank as a stand-in for GNNs.

A simplified GNN: PageRank

In a GNN, each node/edge contains some information, and over multiple iterations, every node passes information to its neighbors, who update their state in response. There’s a lively subfield of machine learning exploring all sorts of variations on how this information is sent, received, and acted upon. (See this review for more)

For our investigation into graph representations, let’s stick to the simplest possible variation - one in which each node sends one number to its neighbors and each neighbor simply averages the information coming in.

As it turns out, this variation was invented more than two decades ago, and is most popularly known as PageRank.

In PageRank, you have each node (a website) distribute its importance score evenly to its neighbors (websites it links to). Iterate until the importance scores converge, and that gives you the importance of each website within the Internet. To handle disjoint subgraphs and to numerically stabilize the algorithm, a fraction of each node’s score is redistributed evenly to every other node in the graph.

As far as linear algebra goes, there’s a closed-form solution, but I’ll approximate the answer by executing the iterative procedure as described above.
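For reference, here is one way to write down both the iterative update and the closed form. The notation is an assumption introduced for illustration: \(A\) is the adjacency matrix, \(P_{ij} = A_{ij} / \sum_k A_{ik}\) is its row-normalized version, \(s^{(t)}\) is the row vector of scores after \(t\) iterations (summing to 1), \(d\) is the fraction of score that follows links, and every node is assumed to have at least one outgoing link.

\[
s^{(t+1)}_j = \frac{1 - d}{N} + d \sum_{i} s^{(t)}_i P_{ij}
\]

Setting \(s^{(t+1)} = s^{(t)} = s\) and solving the resulting linear system gives the fixed point:

\[
s = \frac{1 - d}{N} \, \mathbf{1}^\top (I - dP)^{-1}
\]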

I’ll note that the choice of PageRank as a benchmark favors methods with low overhead, because the amount of computation being done is so small (just a single number being passed along each edge in each iteration). In a typical GNN, a vector would be passed instead, allowing more work to be dispatched for the same overhead.

Implementing PageRank

Naive solution (Python adjacency list)

First, here’s the naive solution, which looks exactly like you would expect, given my verbal description.

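import numpy as np

def pagerank_naive(N, num_iterations=100, d=0.85):
    # Each node tracks its current score and an accumulator for incoming score.
    node_data = [{'score': 1.0/N, 'new_score': 0} for _ in range(N)]
    # adjacency_list(N) is a helper not shown in this essay; it returns one list
    # of out-neighbor ids per node. Every node needs at least one out-edge.
    adj_list = adjacency_list(N)

    for _ in range(num_iterations):
        # Split each node's score evenly among its out-neighbors.
        for from_id, to_ids in enumerate(adj_list):
            score_packet = node_data[from_id]['score'] / len(to_ids)
            for to_id in to_ids:
                node_data[to_id]['new_score'] += score_packet
        # Keep a fraction d of the received score; spread the rest over all nodes.
        for data_dict in node_data:
            data_dict['score'] = data_dict['new_score'] * d + (1 - d) / N
            data_dict['new_score'] = 0
    return np.array([data_dict['score'] for data_dict in node_data])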
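The adjacency_list(N) helper isn't shown in this essay. For experimenting with the function above, a hypothetical stand-in might look like the sketch below. The name random_adjacency_list, the edges_per_node and seed parameters, and the choice of uniformly random targets (self-loops and duplicate edges allowed) are all assumptions made for illustration, not the generator used for the benchmarks.

```python
import random

def random_adjacency_list(N, edges_per_node=3, seed=0):
    """Hypothetical stand-in: every node links to `edges_per_node` random nodes."""
    rng = random.Random(seed)
    return [[rng.randrange(N) for _ in range(edges_per_node)]
            for _ in range(N)]

# Small example: 5 nodes, each with 3 outgoing edges.
adj_list = random_adjacency_list(5)
```

Giving every node a fixed, nonzero number of out-edges sidesteps the division by zero that an empty adjacency list would cause in the naive loop.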

Dense NumPy solution (NumPy adjacency matrix)

The adjacency matrix solution is probably the most elegant implementation of PageRank.

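def pagerank_dense(N, num_iterations=100, d=0.85):
    # adjacency_matrix(N) is a helper not shown in this essay; it returns an N x N array.
    adj_matrix = adjacency_matrix(N)
    # Normalize each row so a node's score is split evenly among its out-neighbors.
    transition_matrix = adj_matrix / np.sum(adj_matrix, axis=1, keepdims=True)
    # Keep a fraction d of each transition; spread the rest evenly over all nodes.
    transition_matrix = d * transition_matrix + (1 - d) / N

    score = np.ones([N], dtype=np.float32) / N
    for _ in range(num_iterations):
        # One propagation step is a single vector-matrix product.
        score = score @ transition_matrix
    return score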
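To make the matrix construction concrete, here is a small worked example on a hypothetical 3-node graph (the graph itself is an assumption chosen for illustration). It follows the same steps as the function above: row-normalize, blend in the uniform redistribution, then iterate.

```python
import numpy as np

# Hypothetical graph: 0 -> {1, 2}, 1 -> {2}, 2 -> {0}
adj_matrix = np.array([[0, 1, 1],
                       [0, 0, 1],
                       [1, 0, 0]], dtype=np.float64)
N, d = 3, 0.85

# Each row sums to 1 after normalization, and still does after damping.
transition_matrix = adj_matrix / adj_matrix.sum(axis=1, keepdims=True)
transition_matrix = d * transition_matrix + (1 - d) / N

score = np.full(N, 1.0 / N)
for _ in range(100):
    score = score @ transition_matrix
```

Because every row of the transition matrix sums to 1, the scores keep summing to 1 throughout. Node 2, which receives score from both other nodes, ends up with the highest score.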

This algorithm most closely matches the math involved. It’s also the fastest of the Python implementations as long as N is small (N < 1000). As mentioned in the introduction, it has one fatal flaw: because it is based on adjacency matrices, memory usage scales as \(O(N^2)\), and each iteration’s vector-matrix product also costs \(O(N^2)\), regardless of the sparsity of the graph. Since most real graphs are sparse, this solution wastes a lot of time multiplying and adding zeros, and does not scale well. At N = 3,000, the original naive solution matches the dense array solution in overall speed.
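A quick back-of-the-envelope calculation shows how fast the memory cost diverges. The sketch below assumes 4-byte entries and 3 edges per node; the function names are hypothetical, and the edge-list estimate simply counts two integers (source and destination) per edge.

```python
def dense_matrix_bytes(num_nodes, itemsize=4):
    # One entry for every (from, to) pair, whether or not an edge exists.
    return num_nodes * num_nodes * itemsize

def flat_edge_list_bytes(num_nodes, edges_per_node=3, itemsize=4):
    # Two integers (source and destination) per edge.
    return 2 * num_nodes * edges_per_node * itemsize

N = 100_000
dense_gb = dense_matrix_bytes(N) / 1e9    # 40.0 GB
flat_mb = flat_edge_list_bytes(N) / 1e6   # 2.4 MB
```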

Sparse solution (NumPy flattened adjacency list)

To scale to larger sparse graphs, we’ll have to use a graph representation that is more compact than an adjacency matrix. The original naive implementation used adjacency lists, so let’s optimize that.

For the most part, the conversion from Python to NumPy is fairly straightforward. The difficulty comes from converting the loops over each node’s adjacency list, which requires rethinking our data representation. An adjacency list is a list of lists of varying lengths (also called a “ragged” array). Since vectorized operations are most efficient with rectangular shapes, we should find some way to coerce this data into rectangular form. (SciPy and other libraries have native support for sparse matrices, but I find it more understandable to manually implement sparsity.)

The simplest way to get rectangular adjacency lists is to pad each adjacency list to the length of the longest such list using some sentinel value.

adj_list = [[0], [1, 2, 3], [4, 5], [6], [7, 8, 9]]
# padded version, using -1 as a sentinel value
[[0, -1, -1], [1, 2, 3], [4, 5, -1], [6, -1, -1], [7, 8, 9]]

However, this doesn’t work very well when there’s a long tail of highly-connected nodes, which is common in real-world graphs. Every node’s adjacency list would be excessively padded, negating any efficiency gains.
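A minimal sketch of the padding approach, and of how badly it degrades with a single hub node. The function name pad_adjacency_list and the skewed example graph (one node with 1,000 neighbors, 999 nodes with one neighbor each) are assumptions made for illustration.

```python
import numpy as np

def pad_adjacency_list(adj_list, sentinel=-1):
    """Pad every adjacency list to the length of the longest one."""
    max_len = max(len(row) for row in adj_list)
    padded = np.full((len(adj_list), max_len), sentinel, dtype=np.int64)
    for i, row in enumerate(adj_list):
        padded[i, :len(row)] = row
    return padded

adj_list = [[0], [1, 2, 3], [4, 5], [6], [7, 8, 9]]
padded = pad_adjacency_list(adj_list)

# Hypothetical skewed graph: one hub with 1000 neighbors, 999 nodes with one.
skewed = [list(range(1000))] + [[0] for _ in range(999)]
padding_fraction = np.mean(pad_adjacency_list(skewed) == -1)
```

In the skewed case the padded array has a million entries, of which only 1,999 are real edges, so over 99% of the storage is sentinel values.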

The next simplest way to do this is to just flatten the adjacency lists into one long array:

adj_list = [[0], [1, 2, 3], [4, 5], [6], [7, 8, 9]]
# flattened version
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

This representation loses information about where each sublist begins and ends. To salvage the situation, we’ll store the originating list index in a separate array:

adj_list = [[0], [1, 2, 3], [4, 5], [6], [7, 8, 9]]
# becomes...
from_nodes = [0, 1, 1, 1, 2, 2, 3, 4, 4, 4]
to_nodes   = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
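One way to build these two arrays from a Python adjacency list is sketched below. The function name flatten_adjacency_list is hypothetical; the approach repeats each node's index once per outgoing edge and concatenates the sublists.

```python
import numpy as np

def flatten_adjacency_list(adj_list):
    """Convert a ragged adjacency list into parallel from/to index arrays."""
    lengths = np.array([len(row) for row in adj_list])
    # Repeat each source index once per outgoing edge.
    from_nodes = np.repeat(np.arange(len(adj_list)), lengths)
    # Concatenate all destination ids in order.
    to_nodes = np.concatenate(
        [np.asarray(row, dtype=np.int64) for row in adj_list])
    return from_nodes, to_nodes

adj_list = [[0], [1, 2, 3], [4, 5], [6], [7, 8, 9]]
from_nodes, to_nodes = flatten_adjacency_list(adj_list)
```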

These two arrays are essentially pointers in vectorized form. With a flattened adjacency list representation, we can vectorize the core loop.

for node, out_nodes in enumerate(adj_list):
    score_packet = node_data[node]['score'] / len(out_nodes)
    for out_node in out_nodes:
        node_data[out_node]['new_score'] += score_packet

# becomes...
# (the division by out-degree is not shown in the vectorized snippets)
# NumPy dialect
score_packets = score[from_nodes]                # gather each edge's source score
new_score = np.zeros_like(score)
np.add.at(new_score, to_nodes, score_packets)    # scatter-add into destinations
# TensorFlow dialect
score_packets = tf.gather(score, from_nodes)
new_score = tf.math.unsorted_segment_sum(score_packets, to_nodes, N)
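Putting the pieces together, a complete NumPy version might look like the sketch below. This is an illustration under stated assumptions, not the benchmarked implementation: the function name pagerank_flat and its signature are hypothetical, the out-degree is recovered with np.bincount, and every node is assumed to have at least one outgoing edge.

```python
import numpy as np

def pagerank_flat(from_nodes, to_nodes, N, num_iterations=100, d=0.85):
    # Number of outgoing edges per node (assumed >= 1 for every node).
    out_degree = np.bincount(from_nodes, minlength=N)
    score = np.full(N, 1.0 / N)
    for _ in range(num_iterations):
        # Gather: each edge carries an equal share of its source's score.
        score_packets = score[from_nodes] / out_degree[from_nodes]
        # Scatter-add: np.add.at accumulates correctly on repeated indices.
        new_score = np.zeros_like(score)
        np.add.at(new_score, to_nodes, score_packets)
        score = d * new_score + (1 - d) / N
    return score

# Hypothetical 3-node cycle 0 -> 1 -> 2 -> 0: every node is equally important.
cycle_scores = pagerank_flat(np.array([0, 1, 2]), np.array([1, 2, 0]), N=3)
```

A plain `new_score[to_nodes] += score_packets` would not work here, because NumPy applies only one update per index when indices repeat. `np.bincount(to_nodes, weights=score_packets, minlength=N)` computes the same sum and is another option worth measuring.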

This representation has other benefits. Pointer latency is the bane of computing (a topic I explore here), and yet pointers are a necessary layer of indirection in graphs. Modern computers try to ameliorate the situation with a hierarchy of caches, where each successive layer has lower latency, but also lower capacity. By reimplementing 64-bit pointers as a compact vector of 16-bit integer offsets into a contiguous array, our data fits better into these caches and benefits from their lower latency.
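The size difference is easy to check in NumPy. The sketch below uses a hypothetical random edge array (10,000 nodes, 30,000 edges) and makes one caveat explicit: a signed 16-bit index can only address node ids up to 32,767, so larger graphs need a wider type.

```python
import numpy as np

N = 10_000
num_edges = 3 * N
rng = np.random.default_rng(0)
from_nodes_64 = rng.integers(0, N, size=num_edges, dtype=np.int64)

# Narrowing is only safe if the largest node id fits in the smaller type.
assert N - 1 <= np.iinfo(np.int16).max
from_nodes_16 = from_nodes_64.astype(np.int16)

score = np.full(N, 1.0 / N)
same_result = np.array_equal(score[from_nodes_64], score[from_nodes_16])
bytes_64 = from_nodes_64.nbytes   # 240_000
bytes_16 = from_nodes_16.nbytes   # 60_000
```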

Performance comparison

I ran all three algorithms on sparse graphs of various sizes, all having roughly 3 times as many edges as vertices.

[Figure: PageRank implementation speed comparison]

This chart demonstrates a number of interesting things:

  • The NumPy implementations have a hockey-stick look, due to the cost of the Python overhead at low N. The naive Python implementation shows no overhead, but you can also argue that it shows only overhead ;).
  • Both naive and vectorized sparse implementations demonstrate the expected linear scaling.
  • My vectorized sparse implementation performs up to 60x faster than the original naive implementation.
  • The dense implementation starts off much faster due to its extreme simplicity (minimizing any Python overhead), but loses to the sparse implementation at around N = 300, because its cost grows with \(N^2\) per iteration rather than with the number of edges.

JIT and C implementations

For fun, I also implemented the flattened adjacency list solution in C, to estimate the performance ceiling. I was surprised to find that despite having optimized Python performance by 60x, another 5x in potential optimizations existed.

I suspected that the gap was due to residual Python dispatch overhead. A potential solution is JIT compilation, which traces the dispatch patterns, eliminating that overhead from the final execution. I tried using Numba, TensorFlow, and JAX to close the gap.

[Figure: PageRank implementation speed comparison, including JIT and C implementations]

Or, expressed in relative speed to the C implementation:

[Figure: PageRank implementation speed relative to the C implementation]

(I’m not familiar with PyTorch; if you’d like to see PyTorch on this graph, send me a pull request).

Numba is the winner in this particular performance benchmark. TensorFlow and JAX also showed modest performance gains of 25%, and their overlap is expected as they’re both being compiled to XLA.

The Numba benchmark comes with some caveats. Numba didn’t actually support compiling the sparse index / gather APIs I’d used, so I had to rewrite the core loop.

# TensorFlow version
score_packets = tf.gather(score, from_nodes)
score = tf.math.unsorted_segment_sum(score_packets, to_nodes, N)

# Numba rewrite
new_score = np.zeros_like(score)
for i in range(from_nodes.shape[0]):
    new_score[to_nodes[i]] += score[from_nodes[i]]
score = new_score
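For readers who want to try the loop form, here is a hedged, self-contained sketch of how it could be wrapped for Numba. The function name scatter_add_scores is hypothetical, and the fallback decorator is an addition so the snippet still runs (slowly, as plain Python) when Numba isn't installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Fallback: without Numba, run the same loop as ordinary Python.
    def njit(func):
        return func

@njit
def scatter_add_scores(score, from_nodes, to_nodes):
    # One pass over the edges: read the source score, add it to the destination.
    new_score = np.zeros_like(score)
    for i in range(from_nodes.shape[0]):
        new_score[to_nodes[i]] += score[from_nodes[i]]
    return new_score

score = np.array([0.1, 0.2, 0.3, 0.4])
from_nodes = np.array([0, 0, 1, 2, 3])
to_nodes = np.array([1, 2, 2, 0, 0])
result = scatter_add_scores(score, from_nodes, to_nodes)
```

With Numba present, the first call includes compilation time, so any benchmark should time a later call.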

For comparison, here’s my C implementation.

// C code
for (int e = 0; e < num_edges; e++) {
    new_scores[dest[e]] += scores[src[e]];
}

I’d consider the Numba implementation akin to inlining some C code into Python. My guess is that the XLA compiler was unable to rewrite this code, leading to TF and JAX’s underperformance.

Conclusion

I’ve shown how to use a flattened adjacency list representation to get performant GNN code. This representation may seem complicated, but I’ve found it to be quite versatile in performantly expressing a wide array of GNN architectures.

While Numba can bring us to within 1.5x performance of handwritten C, it requires essentially writing some inline C and wrestling with obfuscated compilation errors, so I don’t consider this to be a free win.

TensorFlow/JAX can get within 4x performance of handwritten C while using idiomatic NumPy APIs. Given that TensorFlow and JAX also come with automatic differentiation (table stakes for doing machine learning!) and a broader ecosystem of practitioners, I’m pretty happy using those libraries.

You can find the full implementation details on GitHub.

Thanks to Alexey Radul for commenting on drafts of this essay.