AlphaGo Zero: an analysis
Originally posted 2017-12-03
Tagged: computer science, machine learning, alphago
Obligatory disclaimer: all opinions are mine and not of my employer
Where to begin?
I think the most mindblowing part of this paper is just how simple the core idea is: if you have an algorithm that can trade computation time for better assessments, and a way to learn from these better assessments, then you have a way to bootstrap yourself to a higher level. There’s a lot of deserved buzz around how this is an extremely general RL technique that can be used for anything.
I think there are two things special to Go that make this RL technique viable.
First - Go piggybacks off of the success of convolutional networks. Of the many kinds of networks, convnets for image processing have the clearest theoretical justification and the most real-world success. Most games have a more arbitrary, less spatial/geometric sense of gameplay, and would require carefully designed network architectures. Go gets to use a vanilla convnet architecture with almost no modification.
Second - Monte Carlo Tree Search (MCTS) is the logical successor of minimax for turn-based games that have a large branching factor. MCTS has been investigated by computer Go researchers for 10 years now, and the community has had a long time to understand how MCTS behaves in favorable and unfavorable positions, and to discover algorithmic optimizations like virtual losses (more on this later). Another algorithmic breakthrough like MCTS will be needed to handle games that have a continuous time dimension.
Convnets and TPUs
What makes convnets so appropriate for Go? Well, a Go board is a reasonably large square grid, and there’s nothing particularly special about any point, other than its distance from the edge of the board. That’s exactly the kind of input that convnets do well on.
The recent development of residual network architectures has also allowed convnets to scale to ridiculous depths. The original AlphaGo used 12 layers of 3x3 convolutions, which meant that information could only have propagated a Manhattan distance of 24. To compensate, a set of handcrafted input features helped information propagate through the 19x19 board, by computing ladders and shared liberties of a long chain. But with resnets, the 40 (or 80) convolutions apparently eliminate the need for these handcrafted features. Resnets are clearly an upgrade from regular convnets, and it’s unsurprising that the new AlphaGo uses them.
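As a rough check on the propagation argument, here is a small sketch of how far a stack of 3x3 convolutions can carry information. The function names are made up for illustration, and it assumes plain stride-1 convolutions with no dilation.

```python
BOARD_SIZE = 19


def reach_after(num_layers, kernel=3):
    """How far information can travel through stacked stride-1 convolutions.

    Each layer lets a point see (kernel - 1) // 2 steps in every direction,
    diagonals included, so the reach grows linearly with depth.
    Returns (steps along one axis, Manhattan distance along a diagonal).
    """
    per_layer = (kernel - 1) // 2
    axis_steps = num_layers * per_layer
    return axis_steps, 2 * axis_steps


def spans_board(num_layers, board_size=BOARD_SIZE):
    """True if one corner of the board can influence the opposite corner."""
    axis_steps, _ = reach_after(num_layers)
    return axis_steps >= board_size - 1


for layers in (12, 40, 80):
    print(layers, reach_after(layers), spans_board(layers))
```

Under these assumptions, 12 layers reach 12 steps along each axis (24 in Manhattan terms), which is short of the 18 steps needed to cross a 19x19 board, while 40 layers cover it with room to spare.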
The downside of convnets is the rather large amount of computation they require. Compared to a fully connected network of the same size on a 19x19 board, convnets require roughly 14500 times fewer parameters[1]. The number of add-multiplies shrinks far less, though (only by about 361 / 9, or roughly 40x, because each shared weight is reused at all 361 points), and the efficiency gains instead encourage researchers to just make their convnets bigger. The net result is a crapton of required computation. That, in turn, implies a lot of TPUs.
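The parameter comparison can be reproduced in a few lines. This is only a back-of-the-envelope sketch for a single layer, using the 256 channels and 3x3 kernels from footnote [1] and ignoring biases; the variable names are mine.

```python
POINTS = 19 * 19      # 361 board points
CHANNELS = 256        # feature planes per layer, as in footnote [1]
KERNEL_AREA = 3 * 3   # 3x3 convolution

# Convolution: one 3x3 x CHANNELS x CHANNELS filter bank, shared by every point.
conv_params = KERNEL_AREA * CHANNELS * CHANNELS

# Fully connected: every output unit is wired to every input unit.
fc_params = (POINTS * CHANNELS) ** 2

ratio = fc_params / conv_params
print(f"conv: {conv_params:,}  fully connected: {fc_params:,}  ratio: {ratio:.0f}x")
```

The channel count cancels out of the ratio, which is why the footnote can write it as just 361^2 / 9.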
Fig. 1: The modern convnet
A lot of press has focused on the use of “4 TPUs” by AGZ, which is the number of TPUs used by the competition version of AGZ. This is the wrong thing to focus on. The important number is 2000 TPUs, the number used during the self-play bootstrapping phase[2]. If you want to trade computation time for improved play, you’ll need a lot of sustained computation time! There is almost certainly room for more efficient use of computation.
[1] Each pixel in the 19x19x256 output of one conv layer only looks at a 3x3x256 slice of the previous layer, and additionally, each of the 19x19 points share the same weights. Thus, the equivalent fully connected network would have used 361^2 / 9 = 14480x more weights.
[2] This number isn’t given in the paper, but it can be extrapolated from their numbers (5 million self-play games, 250 moves per game, 0.4s thinking time per move = 500 million compute seconds = 6000 compute days, which was done in 3 real days). Therefore, something like 2000 TPUs in parallel. Aja Huang has confirmed this number.
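The extrapolation in footnote [2] is easy to redo. The sketch below uses the footnote's rounded inputs and assumes each move occupies one TPU for its full thinking time; the variable names are only illustrative.

```python
SECONDS_PER_DAY = 24 * 60 * 60

games = 5_000_000         # self-play games
moves_per_game = 250      # rough average game length
seconds_per_move = 0.4    # thinking time per move
wall_clock_days = 3       # real time the self-play took

compute_seconds = games * moves_per_game * seconds_per_move
compute_days = compute_seconds / SECONDS_PER_DAY
tpus_in_parallel = compute_days / wall_clock_days

print(f"{compute_seconds:.3g} compute seconds")
print(f"{compute_days:.0f} compute days")
print(f"{tpus_in_parallel:.0f} TPUs in parallel")
```

The unrounded result comes out a little under 6000 compute days and a little under 2000 TPUs, which is consistent with the ballpark figure.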
MCTS with Virtual Losses
In addition to having a lot of TPUs, it’s important to optimize things so that the TPU is at 100% utilization. The traditional MCTS update algorithm, as described by the AGZ paper itself, goes like this:
- Pick a new variation in the game tree to explore, balancing between reading more deeply into favorable variations, and exploring new variations.
- Ask the neural network what it thinks of that position. (It will return a value estimation, and the most likely followup moves).
- Record the followup moves and value estimate. Increment the visit counts for all positions leading to that variation.
Unfortunately, the algorithm as described here is an inherently sequential process. Because MCTS is deterministic, rerunning step 1 will always return the same variation, until the updated value estimates and visit counts have been incorporated. And yet, the supporting information in the AGZ paper describes batching up positions in multiples of 8, for optimal TPU throughput.
The simplest way to achieve this sort of parallelism is to just play 8 games in parallel. However, there are a few things that make this approach less simple than it initially seems. First is the ragged edges problem: what happens when your games don’t end after the same number of moves? Second is the latency problem: for the purposes of delivering games to the training process, you would rather have 1 game completed every minute than 8 games completed every 8 minutes. Third is that this method of parallelism severely hamstrings your competition strength, where you want to focus all of your computation time on one game.
So we’d like to figure out how to get a TPU (or four) to operate concurrently on one game tree. That method is virtual losses, and it works as follows:
- Pick a new variation in the game tree to explore, balancing between reading more deeply into favorable variations, and exploring new variations. Increment the visit count for all positions leading to that variation.
- Ask the neural network what it thinks of that position. (It will return a value estimation, and the most likely followup moves).
- Record the followup moves and value estimate, realigning value estimates to the standard MCTS algorithm.
The method is called virtual losses, because by incrementing visit counts without adding the value estimate, you are adding a value estimate of “0” - or in other words, a (virtual) loss. The net effect is that you can now rerun the first step because it will give you a different variation each time. Therefore, you can run this step repeatedly to get a batch of 8 positions to send to the TPU, and even while the TPU is working, you can continue to run the first step to prepare the next batch.
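Here is a toy sketch of the idea, not the AGZ implementation. It assumes a PUCT-style selection score, an arbitrary `C_PUCT` of 1.0, value estimates in [0, 1] taken from a single player's point of view (a real engine flips perspective at each ply), and a made-up `Node` class with three hand-picked priors.

```python
import math

C_PUCT = 1.0  # exploration constant; an assumed value, not from the paper


class Node:
    def __init__(self, prior):
        self.prior = prior      # policy prior of the move leading to this node
        self.visits = 0         # visit count, including in-flight virtual losses
        self.value_sum = 0.0    # sum of value estimates, each in [0, 1]
        self.children = {}      # move -> Node

    def q(self):
        # Mean value. A visit whose value has not arrived yet counts as 0 (a loss).
        return self.value_sum / self.visits if self.visits else 0.0


def select_leaf(root, virtual_loss=True):
    """Step 1: walk down to a leaf, optionally counting the visit right away."""
    node, moves, path = root, [], [root]
    while node.children:
        total = sum(c.visits for c in node.children.values())

        def score(item):
            child = item[1]
            explore = C_PUCT * child.prior * math.sqrt(total + 1) / (1 + child.visits)
            return child.q() + explore

        move, node = max(node.children.items(), key=score)
        moves.append(move)
        path.append(node)
    if virtual_loss:
        for n in path:
            n.visits += 1  # the visit is recorded now; the value comes later
    return moves, path


def backup(path, value):
    """Step 3: add the network's value. Visits were already counted in step 1."""
    for n in path:
        n.value_sum += value


def make_root():
    root = Node(prior=1.0)
    root.children = {"a": Node(0.40), "b": Node(0.35), "c": Node(0.25)}
    return root


# Without virtual losses, every selection in a batch lands on the same leaf.
root = make_root()
same = [select_leaf(root, virtual_loss=False)[0] for _ in range(3)]
print(same)      # [['a'], ['a'], ['a']]

# With virtual losses, each selection is steered away from in-flight leaves.
root = make_root()
batch = [select_leaf(root, virtual_loss=True) for _ in range(3)]
distinct = [moves for moves, _ in batch]
print(distinct)  # [['a'], ['b'], ['c']]

# Once the (pretend) network results come back, fold them into the tree.
for (_, path), value in zip(batch, [0.9, 0.5, 0.1]):
    backup(path, value)
```

Virtual losses only discourage repeats. With other priors or a larger batch the same leaf can still be picked twice, so a real implementation would need to deduplicate the batch.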
It seems like a small detail, but virtual losses allow MCTS to scale horizontally by a factor of about 50x. 32x comes from 4 TPUs and a batch size of 8 all working on one game tree, and up to another 2x comes from allowing both CPU and TPU to work concurrently.
MCTS in the Random Regime
The second mindblowing part of the AGZ paper was that this bootstrapping actually worked, even starting from random noise. The MCTS would just be averaging a bunch of random value estimates at the start, so how would it make any progress at all?
Here’s how I think it would work:
- A bunch of completely random games are played. The value net learns that just counting black stones and white stones correlates with who wins. MCTS is also probably good enough to figure out that if black passes early, then white will also pass to end the game and win by komi, at least according to Tromp-Taylor rules. The training data thus lacks any passes, and the policy network learns not to pass.
- MCTS starts to figure out that capture is a good move, because the value net has learned to count stones -> the policy network starts to capture stones. Simultaneously, the value net starts to learn that spots surrounded by stones of one color can be considered to belong to that color.
- ???
- Beat Lee Sedol.
Joking aside, the insight here is that there are only two ways to impart true signal into this system: during MCTS (if two passes are executed), and during training of the value net. The value net will therefore be the first to learn; MCTS will then play moves that guide the game towards high value states (taking the opponent’s moves into consideration); and only then will the policy network start outputting those moves. The resulting games are slightly more complex, allowing the value network to learn more sophisticated ideas. In a way, the whole thing reminds me of Fibonacci numbers: your game strength is a combination of the value network from the last generation and the policy network from two generations past.
Conclusion
AlphaGo Zero’s reinforcement learning algorithm is an accomplishment that I think should have happened a decade from now, but was made possible today because of Go’s amenability to convnets and MCTS.
Miscellaneous thoughts
The first AlphaGo paper trained the policy network first, then froze those weights while training the weights of a value branch. The current AlphaGo concurrently computes both policy and value halves, and trains with a combined loss function. This is really elegant in several ways: the two objectives regularize each other; it halves the computation time required; and it integrates perfectly with MCTS, which requires evaluating both policy and value parts for each variation investigated.
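For a single training position, a combined loss could look something like the sketch below. It assumes the form I understand the AGZ paper to use (squared error on the value plus cross-entropy between the MCTS visit distribution and the policy output) and leaves out the L2 weight penalty. The function and argument names are made up for illustration.

```python
import math


def combined_loss(policy_probs, search_probs, value, outcome):
    """Loss for one position from a two-headed network.

    policy_probs: the network's move probabilities (summing to 1)
    search_probs: the MCTS visit distribution, used as the policy target
    value:        the network's value estimate for the position
    outcome:      the actual game result, on the same scale as value
    """
    value_loss = (outcome - value) ** 2
    # Cross-entropy; moves that the search never visited contribute nothing.
    policy_loss = -sum(
        t * math.log(p) for t, p in zip(search_probs, policy_probs) if t > 0
    )
    return value_loss + policy_loss


loss = combined_loss([0.7, 0.2, 0.1], [0.8, 0.2, 0.0], value=0.5, outcome=1.0)
print(round(loss, 4))
```

Because both terms are computed from one forward pass, a single gradient step updates the shared layers for both objectives at once.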
400 games were played to evaluate whether a candidate was better than its predecessor because, with 400 games, the 55% win rate threshold sits at exactly 2 standard deviations. In particular, the variance of a binomial process is Np(1-p), so if the null hypothesis is p = 0.5 (the candidates do not differ), then the expected variance is 400 * 0.5 * 0.5 = 100. The standard deviation is thus +/- 10 wins, and 200 wins +/- 10 corresponds to a 50% +/- 2.5% win rate. A 55% win rate is therefore 2 standard deviations above the null hypothesis.
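The same arithmetic as a tiny sketch, using the normal approximation to the binomial; `z_score` is just a helper name chosen for this example.

```python
import math


def z_score(wins, games, p_null=0.5):
    """Standard deviations between an observed win count and the null hypothesis."""
    expected = games * p_null
    std_dev = math.sqrt(games * p_null * (1 - p_null))
    return (wins - expected) / std_dev


print(z_score(220, 400))  # 55% of 400 games -> 2.0
print(z_score(55, 100))   # 55% of 100 games -> 1.0
```

With only 100 games, the same 55% win rate would be a single standard deviation from a coin flip, which gives a feel for why the evaluation needs a few hundred games.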
There were a lot of interconnected hyperparameters, making this far from a drop-in algorithm:
- the number of moves randomly sampled at the start to generate enough variation, compared to the size of the board, the typical length of a game, and the typical branching factor/entropy introduced by each move.
- the size of the board, the dispersion factor of the Dirichlet noise, the \(L_2\) regularization strength that causes the policy net’s outputs to be more dispersed, the number of MCTS searches that were used to overcome the magnitude of the priors, and the \(P_{UCT}\) constant (not given in paper) used to weight the policy priors against the MCTS searches.
- the learning rate, batch size, and training speed as compared to the rate at which new games were generated.
- the depth of old games that were revisited during training, compared to the rate at which new versions of AGZ were promoted.
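To make the Dirichlet noise item in the list above a bit more concrete, here is a sketch of how such noise could be blended into the policy priors at the root of the search. The defaults (alpha = 0.03, epsilon = 0.25) are the values I recall from the AGZ paper, so treat them as assumptions; the function name is invented, and only the standard library is used.

```python
import random


def add_dirichlet_noise(priors, alpha=0.03, epsilon=0.25, rng=random):
    """Blend policy priors with Dirichlet(alpha) noise: (1 - eps) * p + eps * noise."""
    # A Dirichlet sample is a set of independent Gamma(alpha, 1) draws, normalized.
    draws = [rng.gammavariate(alpha, 1.0) for _ in priors]
    total = sum(draws)
    if total == 0.0:
        # Guard against underflow, which is possible when alpha is tiny.
        noise = [1.0 / len(priors)] * len(priors)
    else:
        noise = [d / total for d in draws]
    return [(1 - epsilon) * p + epsilon * n for p, n in zip(priors, noise)]


rng = random.Random(0)  # seeded only so the example is repeatable
priors = [0.5, 0.3, 0.2]
noisy = add_dirichlet_noise(priors, rng=rng)
print(noisy)
```

A small alpha concentrates most of the noise on one or two moves, so how much it perturbs the search depends on how many legal moves there are, which is one reason this parameter is tied to the size of the board.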
On the subject of sampling old games - I wonder if it’s analogous to the way that strong Go players have an intuitive sense of whether a group is alive. They’re almost always right, but maybe they need to read it out every so often to keep their intuition sharp.