Data-oriented Programming in Python
Originally posted 2020-09-13
Obligatory disclaimer: all opinions are mine and not of my employer
Many users of Python deprioritize performance in favor of soft benefits like ergonomics, business value, and simplicity. Users who prioritize performance typically end up on faster compiled languages like C++ or Java.
One group of users is left behind, though. The scientific computing community has lots of raw data they need to process, and would very much like performance. Yet, they struggle to move away from Python, because of network effects, and because Python’s beginner-friendliness is appealing to scientists for whom programming is not a first language. So, how can Python users achieve some fraction of the performance that their C++ and Java friends enjoy?
In practice, scientific computing users rely on the NumPy family of libraries: NumPy, SciPy, TensorFlow, PyTorch, CuPy, JAX, and others. The sheer proliferation of these libraries suggests that the NumPy model is getting something right. In this essay, I’ll talk about what makes NumPy so effective, and where the next generation of Python numerical computing libraries (e.g. TensorFlow, PyTorch, JAX) seems to be headed.
Data good, pointers bad
A pesky fact of computing is that computers can compute far faster than we can deliver data to compute on. In particular, data transfer latency is the Achilles’ heel of data devices (both RAM and storage). Manufacturers disguise this weakness by emphasizing improvements in data transfer throughput, but latency continues to stagnate. Ultimately, this means that any chained data access patterns, where one data retrieval must be completed before the next may proceed, are the worst case for computers.
These worst-case chained data access patterns are unfortunately quite common – so common that they have a name you may be familiar with: a pointer.
Pointers have always been slow. In the ’80s and ’90s, our hard drives were essentially optimized record players, with a read head riding on top of a spinning platter. These hard drives had physical limitations: The disk could only spin so fast without shattering, and the read head was also mechanical, limiting its movement speed. Disk seeks were slow, and the programs that were most severely affected were databases. Some ways that databases dealt with these physical limitations are:
- Instead of using binary trees (requiring \(\log_2 N\) disk seeks), B-trees with a much higher branching factor \(k\) were used, only requiring \(\log_k N\) disk seeks.
- Indices were used to query data without having to read the full contents of each row.
- Vertically-oriented databases optimized for read-heavy workloads (e.g. summary statistics over one field, across entire datasets), by reorganizing from arrays of structs to structs of arrays. This maximized effective disk throughput, since no extraneous data was loaded.
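To make the first bullet concrete, here is a minimal sketch that counts how many levels (and hence, roughly, disk seeks) a lookup needs for a given branching factor. The table size and branching factor are hypothetical numbers chosen for illustration, and `seeks_needed` is a made-up helper, not part of any database.

```python
def seeks_needed(num_keys: int, branching_factor: int) -> int:
    """Count the tree levels needed so that the leaves can address num_keys keys.

    Uses integer arithmetic instead of math.log to avoid rounding surprises.
    """
    levels = 0
    reachable = 1
    while reachable < num_keys:
        reachable *= branching_factor
        levels += 1
    return levels


num_keys = 10**9  # hypothetical table size

binary_seeks = seeks_needed(num_keys, 2)     # binary tree
btree_seeks = seeks_needed(num_keys, 1000)   # B-tree with a hypothetical k = 1000

print(binary_seeks, btree_seeks)  # 30 3
```

Each level is one chained access, so a tenfold reduction in levels is a tenfold reduction in the latency paid per lookup.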
Today, compute speed is roughly \(10^5\) to \(10^6\) times faster than in 1990, and RAM is roughly \(10^5\) times faster than HDDs from 1990. I was amused and unsurprised to find that Raymond Hettinger’s excellent talk on the evolution of Python’s in-memory dict implementation plays out like a brief history of early database design. Time, rather than healing things, has only worsened the compute-memory imbalance.
In many higher-level languages, raw data comes in boxes containing metadata and a pointer to the actual data. In Python, the PyObject box holds reference counts, so that the garbage collector can operate generically on all Python entities.
Boxing creates two sources of inefficiency:
- The metadata bloats the data, reducing the data density of our expensive memory.
- The pointer indirection creates another round trip of memory retrieval latency.
A NumPy array can hold many raw data within a single PyObject box, provided that all of those data are of the same type (int32, float32, etc.). By doing this, NumPy amortizes the cost of boxing over multiple data.
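A rough way to see the amortization is to compare bytes per value in the two layouts. This is a sketch, assuming CPython and an installed NumPy. The exact boxed figure depends on the interpreter build, so no number is claimed for it here.

```python
import sys

import numpy as np

n = 100_000

# Boxed layout: a list of pointers, each pointing at a separate float object.
boxed = [float(i) for i in range(n)]
boxed_bytes = sys.getsizeof(boxed) + sum(sys.getsizeof(v) for v in boxed)

# Packed layout: one box around a contiguous buffer of raw float32 values.
packed = np.arange(n, dtype=np.float32)
packed_bytes = packed.nbytes  # data buffer only; the array header is a small one-time cost

print(f"boxed:  {boxed_bytes / n:.1f} bytes per value")
print(f"packed: {packed_bytes / n:.1f} bytes per value")
```

Part of the gap comes from the narrower dtype (a Python float is double precision), but most of it is the per-object metadata and the pointer held by the list.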
In my previous investigations into Monte Carlo tree search, a naive UCT implementation performed poorly because it instantiated millions of UCTNode objects whose sole purpose was to hold a handful of float32 values. In the optimized UCT implementation, these nodes were replaced with NumPy arrays, reducing memory usage by a factor of 30.
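The shape of that change can be sketched as follows. The class and field names (`BoxedNode`, `NodeTable`, `visits`, `total_value`, `prior`) are hypothetical stand-ins and not the actual code from the UCT implementation. The point is only the move from one object per node to one array per field.

```python
import numpy as np


class BoxedNode:
    """One Python object per node: the 'array of structs' layout."""

    def __init__(self, visits, total_value, prior):
        self.visits = visits
        self.total_value = total_value
        self.prior = prior


class NodeTable:
    """All nodes share three arrays: the 'struct of arrays' layout."""

    def __init__(self, capacity):
        self.visits = np.zeros(capacity, dtype=np.float32)
        self.total_value = np.zeros(capacity, dtype=np.float32)
        self.prior = np.zeros(capacity, dtype=np.float32)

    def mean_values(self):
        # One vectorized expression covers every node; unvisited nodes
        # divide by 1 instead of 0.
        return self.total_value / np.maximum(self.visits, 1)


table = NodeTable(capacity=4)
table.visits[:] = [0, 1, 2, 4]
table.total_value[:] = [0, 1, 1, 2]
print(table.mean_values())
```

A node is now just an integer index into the arrays, so there is nothing to box per node.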
Attribute lookup / function dispatch costs
Python’s language design forces an unusually large amount of pointer chasing. I mentioned boxing as one layer of pointer indirection, but really it’s just the tip of the iceberg.
Python has no problem handling the following code, even though each of these multiplications invokes a completely different implementation.
>>> mixed_list = [1, 1.0, 'foo', ('bar',)]
>>> for obj in mixed_list:
...     print(obj * 2)
...
2
2.0
foofoo
('bar', 'bar')
Python accomplishes this with a minimum of two layers of pointer indirection:
- Look up the type of the object.
- Look up and execute the __mul__ function from that type’s operation registry.
Additional layers of pointer indirection may be required if the __mul__ method is defined on a superclass: the chain of superclasses must be traversed, one pointer at a time, until an implementation is found.
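The two steps, plus the superclass walk, can be spelled out by hand. This is a conceptual sketch with hypothetical classes `Base` and `Child`. CPython’s real machinery (type slots and method caches) is more involved, but the lookup order is the same.

```python
class Base:
    def __mul__(self, other):
        return f"Base.__mul__ x{other}"


class Child(Base):
    pass  # no __mul__ here, so the lookup continues to Base


obj = Child()

# Step 1: look up the object's type. Step 2: look up __mul__ on that type.
explicit = type(obj).__mul__(obj, 2)
assert explicit == obj * 2

# Walk the chain of superclasses until some class defines __mul__ itself.
owner = next(cls for cls in type(obj).__mro__ if "__mul__" in vars(cls))
print(owner.__name__)  # Base
```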
Attribute lookup is similarly fraught; hooks such as __getattribute__ provide users with flexibility that incurs pointer chasing overhead with something as simple as executing a.b. Access patterns like a.b.c.d create exactly the chained data access patterns that are a worst case for data retrieval latency.
To top it all off, merely resolving a variable name can be expensive: there’s a stack of scopes (local, enclosing, global, then built-in) that are checked in order to find the name. In CPython, the global and built-in checks are typically dictionary lookups, another source of pointer indirection.
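The lookup order can be seen in a few lines. The function names here are hypothetical, and the example also includes the built-in scope, which is consulted last when a name is found nowhere else.

```python
import builtins

x = "global"


def outer():
    x = "enclosing"

    def inner():
        return x  # not local to inner, so the enclosing scope is checked next

    return inner()


def uses_global():
    return x  # not local and not enclosing, so the global scope is checked


def uses_builtin():
    return len  # not local, enclosing, or global: found in builtins


print(outer(), uses_global(), uses_builtin() is builtins.len)
```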
As the saying goes: “We can solve any problem by introducing an extra level of indirection… except for the problem of too many levels of indirection”. The NumPy family of libraries deals with this indirection, not by removing it, but again by sharing its cost over multiple data.
>>> import numpy as np
>>> homogeneous_array = np.arange(5, dtype=np.float32)
>>> multiply_by_two = homogeneous_array * 2
>>> multiply_by_two
array([0., 2., 4., 6., 8.], dtype=float32)
Sharing a single box for multiple data allows NumPy to retain the expressiveness of Python while minimizing the cost of the dynamism. As before, this works because of the additional constraint that all data in a NumPy array must have identical type.
The Frontier: JIT
So far, we’ve seen that NumPy doesn’t solve any of Python’s fundamental problems when it comes to pointer overhead. Instead, it merely puts a bandaid on the problem by sharing those costs across multiple data. It’s a pretty successful strategy – in my hands (1, 2), I find that NumPy can typically achieve 30-60x speedups over pure Python solutions to dense numerical code. However, given that C code typically achieves 100-200x performance over pure Python on dense numerical code (common in scientific computing), it would be nice if we could further reduce the Python overhead.
Tracing JITs promise to do exactly this. Roughly, the strategy is to trace the execution of the code and record the pointer chasing outcomes. Then, when you call the same code snippet, reuse the recorded outcomes! NumPy amortizes Python overhead over multiple data, and JIT amortizes Python overhead over multiple function calls.
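The record-then-replay idea can be shown with a toy. This is a minimal sketch, not how TensorFlow or JAX are implemented: `Tracer` and `toy_jit` are hypothetical names, the tracer only understands `*` and `+` with constants, and it only handles a single linear chain of operations.

```python
import numpy as np


class Tracer:
    """Stands in for an array during tracing and records each operation."""

    def __init__(self, tape):
        self.tape = tape

    def __mul__(self, other):
        self.tape.append((np.multiply, other))
        return self

    def __add__(self, other):
        self.tape.append((np.add, other))
        return self


def toy_jit(fn):
    cache = {}  # (shape, dtype name) -> recorded tape of operations

    def wrapper(x):
        key = (x.shape, x.dtype.name)
        if key not in cache:
            tape = []
            fn(Tracer(tape))  # run the Python code once, recording the ops
            cache[key] = tape
            wrapper.traces += 1
        result = x
        for op, operand in cache[key]:  # replay without re-running fn
            result = op(result, operand)
        return result

    wrapper.traces = 0
    return wrapper


@toy_jit
def scale_and_shift(x):
    return x * 2 + 1


first = scale_and_shift(np.arange(3, dtype=np.float32))
second = scale_and_shift(np.ones(3, dtype=np.float32))
print(first, second, scale_and_shift.traces)
```

The Python body of `scale_and_shift` runs once. The second call pays only for the replay, which in a real system would be a compiled program rather than a Python loop.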
(I should note that I’m most familiar with the tracing JITs used by TensorFlow and JAX. PyPy and Numba are two alternate JIT implementations that have a longer history, but I don’t know enough about them to treat them fairly, so my apologies to readers.)
Unsurprisingly, there’s a catch to all this. NumPy can amortize Python overhead over multiple data, but only if that data is the same type. JIT can amortize Python overhead over multiple function calls, but only if the function calls would have resulted in the same pointer chasing outcomes. Retracing the function to verify this would defeat the purpose of JIT, so instead, TensorFlow/JAX JIT uses array shape and dtype to guess at whether a trace is reusable. This heuristic is necessarily conservative, rules out otherwise legal programs, often requires unnecessarily specific shape information, and doesn’t make any guarantees against mischievous tinkering. Furthermore, data-dependent tracing is a known issue (1, 2). I worked on AutoGraph, a tool to address data-dependent tracing. Still, the engineering benefits of a shared tracing infrastructure are too good to pass up. I expect to see JIT-based systems flourish in the future and iron out their user experience.
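Both the cache-key heuristic and its blind spot can be reproduced with a toy tracer. This is a sketch under stated assumptions, not the TensorFlow or JAX implementation: `Tracer`, `toy_jit`, `rescale`, and `scale` are hypothetical names, and the tracer records only multiplications.

```python
import numpy as np


class Tracer:
    """Placeholder passed to the function during tracing; records multiplications."""

    def __init__(self):
        self.factors = []

    def __mul__(self, factor):
        self.factors.append(factor)
        return self


def toy_jit(fn):
    cache = {}  # (shape, dtype name) -> factors recorded during tracing

    def wrapper(x):
        key = (x.shape, x.dtype.name)
        if key not in cache:  # heuristic: same shape and dtype means same trace
            tracer = Tracer()
            fn(tracer)
            cache[key] = tracer.factors
        result = x
        for factor in cache[key]:
            result = result * factor
        return result

    wrapper.cache = cache
    return wrapper


scale = 2  # Python-level state that is invisible to the cache key


@toy_jit
def rescale(x):
    return x * scale


x = np.ones(3, dtype=np.float32)
first = rescale(x)   # traced while scale == 2
scale = 10           # "mischievous tinkering" after tracing
stale = rescale(x)   # same shape and dtype: the old trace is reused
fresh = rescale(np.ones(4, dtype=np.float32))  # new shape forces a retrace

print(first, stale, fresh)
```

The second call silently keeps multiplying by 2, while a mere change of shape triggers a full retrace even though the function’s logic never depended on shape. That is the heuristic being both too permissive and too conservative at once.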
The NumPy API specifically addresses Python’s performance problems for the kinds of programs that scientific computing users want to write. It encourages users to write code in ways that minimize pointer overhead. Coincidentally, this way of writing code is a fruitful abstraction for tracing JITs targeting massively parallel computing architectures like GPUs and TPUs. (Some people argue that machine learning is stuck in a rut due to this NumPy monoculture.) In any case, tracing JITs built on top of NumPy-like APIs are flourishing, and they are by far the easiest way to access the exponentially growing compute available on the cloud. It’s a good time to be a Python programmer in machine learning!
Thanks to Alexey Radul for commenting on drafts of this essay.